In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, random_split

In [2]:
torch.cuda.is_available()

True

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
model = models.resnet18(pretrained=True)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 61.7MB/s]


In [5]:
transform = transforms.ToTensor()

trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                             download=True, transform=transform)
testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                            download=True, transform=transform)

# Split the training set for validation
train_size = 50000
val_size = 10000
train_dataset, val_dataset = random_split(trainset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(testset, batch_size=8, shuffle=False)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:02<00:00, 12.8MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 191kB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:01<00:00, 3.64MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 22.3MB/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [7]:
def train_one_epoch():
    model.train(True)
    running_loss = 0.0
    running_accuracy = 0.0

    for batch_index, data in enumerate(train_loader):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        correct = torch.sum(labels == torch.argmax(outputs, dim=1)).item()
        running_accuracy += correct / len(inputs)

        loss = criterion(outputs, labels)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        if batch_index % 500 == 499:
            avg_loss_across_batches = running_loss / 500
            avg_acc_across_batches = (running_accuracy / 500) * 100
            print(f'Batch {batch_index+1}, Loss: {avg_loss_across_batches:.3f}, Accuracy: {avg_acc_across_batches:.1f}%')
            running_loss = 0.0
            running_accuracy = 0.0
    print()

In [8]:
def validate_one_epoch():
    model.train(False)
    running_loss = 0.0
    running_accuracy = 0.0

    for data in val_loader:
        inputs, labels = data[0].to(device), data[1].to(device)
        with torch.no_grad():
            outputs = model(inputs)
            correct = torch.sum(labels == torch.argmax(outputs, dim=1)).item()
            running_accuracy += correct / len(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

    avg_loss_across_batches = running_loss / len(val_loader)
    avg_acc_across_batches = (running_accuracy / len(val_loader)) * 100

    print(f'Val Loss: {avg_loss_across_batches:.3f}, Val Accuracy: {avg_acc_across_batches:.1f}%')
    print('***************************************************')
    print()

In [9]:
num_epochs = 10
for epoch_index in range(num_epochs):
    print(f'Epoch: {epoch_index + 1}\n')
    train_one_epoch()
    validate_one_epoch()

print('Finished Training')

Epoch: 1

Batch 500, Loss: 1.425, Accuracy: 51.1%
Batch 1000, Loss: 1.072, Accuracy: 62.7%
Batch 1500, Loss: 0.872, Accuracy: 70.3%
Batch 2000, Loss: 0.777, Accuracy: 72.9%
Batch 2500, Loss: 0.722, Accuracy: 75.1%
Batch 3000, Loss: 0.671, Accuracy: 77.0%
Batch 3500, Loss: 0.638, Accuracy: 77.5%
Batch 4000, Loss: 0.585, Accuracy: 80.4%
Batch 4500, Loss: 0.590, Accuracy: 79.8%
Batch 5000, Loss: 0.549, Accuracy: 81.1%
Batch 5500, Loss: 0.518, Accuracy: 82.4%
Batch 6000, Loss: 0.531, Accuracy: 82.5%

Val Loss: 0.412, Val Accuracy: 85.1%
***************************************************

Epoch: 2

Batch 500, Loss: 0.467, Accuracy: 84.4%
Batch 1000, Loss: 0.469, Accuracy: 83.0%
Batch 1500, Loss: 0.507, Accuracy: 83.7%
Batch 2000, Loss: 0.462, Accuracy: 84.1%
Batch 2500, Loss: 0.462, Accuracy: 84.0%
Batch 3000, Loss: 0.443, Accuracy: 85.0%
Batch 3500, Loss: 0.414, Accuracy: 85.5%
Batch 4000, Loss: 0.410, Accuracy: 86.0%
Batch 4500, Loss: 0.404, Accuracy: 86.2%
Batch 5000, Loss: 0.412, Accur