In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
class Net(nn.Module):
    """Convolutional classifier for single-channel 28x28 images.

    Four Conv(3x3, padding=1) -> BatchNorm -> ReLU stages, with 2x2 max pooling
    after the first and third stage, followed by a three-layer MLP head that
    outputs 35 unnormalized scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        def dense_stage(in_features, out_features):
            return [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Dropout(0.25),
            ]

        # Input: 1x28x28
        features = []
        features += conv_stage(1, 32)                  # 32x28x28
        features.append(nn.MaxPool2d(kernel_size=2))   # 32x14x14
        features += conv_stage(32, 64)                 # 64x14x14
        features += conv_stage(64, 128)                # 128x14x14
        features.append(nn.MaxPool2d(kernel_size=2))   # 128x7x7
        features += conv_stage(128, 128)               # 128x7x7
        self.feature_extractor = nn.Sequential(*features)

        head = [nn.Flatten()]
        head += dense_stage(128 * 7 * 7, 512)
        head += dense_stage(512, 256)
        head.append(nn.Linear(256, 35))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for ``x``.

        ``x`` is reshaped with ``view(-1, 1, 28, 28)``, so any contiguous input
        holding a multiple of 784 elements is accepted. The output has shape
        (batch, 35), except that the trailing ``squeeze()`` turns a single
        sample's (1, 35) result into shape (35,).
        """
        batch = x.view(-1, 1, 28, 28)
        logits = self.classifier(self.feature_extractor(batch))
        # The original author marked this squeeze as very important to keep as-is.
        # NOTE(review): for a final DataLoader batch of size 1 this likely yields
        # (35,) logits against (1, 35) one-hot targets in CrossEntropyLoss - confirm.
        return logits.squeeze()

In [4]:
def sharpen_image(tensor, upper_treshold=0.75, lower_treshold=0.25):
    """Push the extremes of ``tensor`` to 0/1 in place and return the same object.

    Values strictly above ``upper_treshold`` are set to 1. Then values strictly
    below ``lower_treshold`` are set to 0. Values in between are left untouched.
    The input is mutated; no copy is made.
    """
    above = tensor > upper_treshold
    tensor[above] = 1
    # This mask is computed after the first assignment, matching the sequential semantics.
    below = tensor < lower_treshold
    tensor[below] = 0
    return tensor

In [5]:
class OneHotImageFolder(ImageFolder):
    """``ImageFolder`` whose targets are one-hot float vectors instead of integers.

    Args:
        root: dataset root directory, forwarded to ``ImageFolder``.
        transform: optional image transform, forwarded to ``ImageFolder``.
        num_classes: length of the one-hot vector; when falsy (``None``/0) the
            number of classes discovered by ``ImageFolder`` is used.
    """

    def __init__(self, root, transform=None, num_classes=None):
        super().__init__(root, transform=transform)
        self.num_classes = num_classes or len(self.classes)

    def __getitem__(self, idx):
        """Return ``(image, one_hot_label)``; the label is a float tensor moved to ``device``."""
        sample, class_index = super().__getitem__(idx)
        index_tensor = torch.tensor(class_index)
        encoded = functional.one_hot(index_tensor, num_classes=self.num_classes)
        # NOTE(review): creating tensors on a CUDA `device` inside __getitem__ is
        # likely to fail if a DataLoader uses num_workers > 0 - confirm before changing workers.
        return sample, encoded.float().to(device)


In [6]:
# Training pipeline: random augmentation on the PIL image, then tensor-side
# resizing, intensity inversion and sharpening.
# NOTE(review): RandomRotation fills the exposed corners with 0 *before* the
# 1 - x inversion, so they become 1 afterwards - confirm this matches the
# source image polarity. Horizontal flips also mirror characters; confirm
# that is acceptable for these 35 classes.
train_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.to(device)),
    transforms.Resize((28, 28)),
    transforms.Lambda(lambda img: 1 - img),
    transforms.Lambda(lambda img: sharpen_image(img, upper_treshold=0.7, lower_treshold=0.2)),
])

In [7]:
# Evaluation pipeline: identical to the training pipeline minus the random
# augmentation steps, so results are deterministic.
test_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.to(device)),
    transforms.Resize((28, 28)),
    transforms.Lambda(lambda img: 1 - img),
    transforms.Lambda(lambda img: sharpen_image(img, upper_treshold=0.7, lower_treshold=0.2)),
])

In [8]:
# Pipeline for user-supplied images.
user_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.to(device)),
    # First pass: invert, then a 2x2 max pool (halves H and W, keeping the
    # brightest pixel of each window), then sharpen.
    transforms.Lambda(lambda img: 1 - img),
    nn.MaxPool2d(kernel_size=2),
    transforms.Lambda(lambda img: sharpen_image(img, upper_treshold=0.7, lower_treshold=0.4)),
    transforms.Resize((28, 28)),
    # Second pass: invert back and sharpen again.
    # NOTE(review): the two inversions cancel out, while the train/test
    # pipelines invert once - presumably user images have the opposite
    # polarity of the training images; confirm.
    transforms.Lambda(lambda img: 1 - img),
    transforms.Lambda(lambda img: sharpen_image(img, upper_treshold=0.6, lower_treshold=0.4)),
])

In [9]:
# One dataset per image directory, each with its own preprocessing pipeline.
dataset_train = OneHotImageFolder("images_train", transform=train_transforms)
dataset_test = OneHotImageFolder("images_test", transform=test_transforms)
dataset_user = OneHotImageFolder("images_user", transform=user_transforms)

In [10]:
# Training batches are reshuffled every epoch.
dataloader_train = DataLoader(dataset_train, batch_size=256, shuffle=True)
# Evaluation reads the held-out images_test split; order is irrelevant for
# evaluation, so no shuffling.
dataloader_test = DataLoader(dataset_test, batch_size=128, shuffle=False)

In [11]:
class_to_idx = dataset_train.class_to_idx
# Inverse lookup: integer class index -> class name.
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))

In [12]:
def show_train(index):
    """Plot training sample ``index`` and print the class name of its label.

    ``dataset_train`` applies random augmentation (flip/rotation), so repeated
    calls with the same index may show different images.
    """
    image, label = dataset_train[index]
    img_arr = image.detach().cpu().numpy().squeeze()
    plt.imshow(img_arr, cmap='gray')
    plt.show()
    # Same reporting as show_test: decode the one-hot label to its class name.
    print(idx_to_class[np.argmax(label.cpu().numpy())])

In [13]:
def show_test(index):
    """Plot test sample ``index`` and print the class name of its one-hot label."""
    picture, target = dataset_test[index]
    pixels = picture.detach().cpu().numpy().squeeze()
    plt.imshow(pixels, cmap='gray')
    plt.show()
    class_id = np.argmax(target.cpu().numpy())
    print(idx_to_class[class_id])

In [14]:
def show_user():
    """Plot every image in ``dataset_user`` and print the network's predicted class.

    Inference runs with ``net`` in eval mode and without gradient tracking.
    Eval mode disables dropout and makes BatchNorm use, rather than update, its
    running statistics. The network's previous train/eval mode is restored
    afterwards, even if an error occurs.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for i in range(len(dataset_user)):
                image, _ = dataset_user[i]
                img_arr = image.detach().cpu().numpy().squeeze()
                plt.imshow(img_arr, cmap='gray')
                plt.show()
                output = net(image)
                print(idx_to_class[np.argmax(output.cpu().numpy().squeeze())])
    finally:
        net.train(was_training)

In [15]:
net = Net()
net.to(device)  # nn.Module.to moves parameters/buffers in place and returns the module itself

In [16]:
# The dataset yields one-hot float targets, which CrossEntropyLoss treats as class probabilities.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=5e-3)

In [17]:
num_epochs = 15
for epoch in range(num_epochs):
    net.train()  # dropout on, BatchNorm uses/updates batch statistics
    train_loss = 0.0
    for batch_idx, (features, labels) in enumerate(dataloader_train):
        optimizer.zero_grad()
        outputs = net(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Periodic progress report: running mean loss over the batches seen so far this epoch.
        if batch_idx % 1000 == 0:
            print(f"epoch [{epoch+1}/{num_epochs}]:    batch:[{batch_idx}/{len(dataloader_train)}]  loss: {train_loss/(batch_idx+1)}")

    print('--------------------------------------------------------------------------------------------------------------------------------')
    # Epoch mean: divide by the number of *training* batches that were summed above.
    print(f"epoch [{epoch+1}/{num_epochs}]: loss: {train_loss/len(dataloader_train):.6f}")
    print('--------------------------------------------------------------------------------------------------------------------------------')

epoch [1/15]:    batch:[0/2735]  loss: 3.5646467208862305


KeyboardInterrupt: 