In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

In [2]:
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)
        return sample, label

In [3]:
# standart format MNIST (N, C, H, W)
X = torch.rand(100, 1, 28, 28) # features
# flatten 100 labels with classes from 0 to 9
y = torch.randint(0, 10, (100,)) # targets

In [4]:
# mean, std for the greychannel, [0,1] -> [-1,1]
dataset = CustomDataset(X, y, transform=transforms.Normalize((0.5,), (0.5)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [5]:
class SimpleModel(nn.Module):
    def __init__(self):
        # like nn.Module.__init__(self)
        # sets up stuff like self._parameters dict for trainable weights and biases
        # without, print(list(model.parameters())) not possible
        super(SimpleModel, self).__init__()
        self.flatten = nn.Flatten() # multidim into 1d fully connected
        self.fc = nn.Sequential(
            # in 74 features out 128 fully connected
            nn.Linear(28*28, 128),
            # 128, 128
            nn.ReLU(),
            # 128, 10 output classes
            nn.Linear(128,10)
        )
    
    def forward(self, x):
        # input pics
        x = self.flatten(x)
        # flat arr
        x = self.fc(x)
        # class scores
        return x
    
model = SimpleModel()

In [6]:
# standard criterion for multiclass
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
num_epochs=5
for epoch in range(num_epochs):
    model.train() # activates dropout and batch nromalization if implemented
    running_loss = 0.0
    for inputs, labels in dataloader: # batches
        optimizer.zero_grad() # deletes grads from batch before
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward() # calc weights through backpropagation
        optimizer.step() # update weights with adam
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

Epoch [1/5], Loss: 2.3155
Epoch [2/5], Loss: 1.9403
Epoch [3/5], Loss: 1.6255
Epoch [4/5], Loss: 1.4896
Epoch [5/5], Loss: 1.2949


In [None]:
model.eval() # deactivates dropout, freezes batch norm, important for consistent prediction
with torch.no_grad():
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1) # highest score class
        total += labels.size(0)
        correct += (predicted == labels).sum().item() # count correct

    print(f"Accuracy: {100 * correct / total:.2f}%")

Accuracy: 100.00%
