In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision  import models

In [34]:
device=torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)


mps


In [35]:
#Augmentation
transform = transforms.Compose([
	transforms.RandomRotation(10),
	transforms.RandomHorizontalFlip(),
	transforms.ToTensor(),
	transforms.Normalize((0.5,), (0.5,))
    

])

In [36]:
train_dataset=torchvision.datasets.FashionMNIST(root='data',download=False,train=True,transform=transform)
test_dataset=torchvision.datasets.FashionMNIST(root="data",download=False,train=False,transform=transforms.ToTensor())

In [37]:
train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=64,shuffle=False)

In [38]:
#Loading the Resnet-18 weights
model=models.resnet18(weights='ResNet18_Weights.DEFAULT')
#Adjusting for input grey scale
model.conv1=nn.Conv2d(in_channels=1,out_channels=64,kernel_size=3,stride=1,padding=1,bias=False)
model.bn1=nn.BatchNorm2d(64)
model.fc=nn.Linear(model.fc.in_features,10)#as there are 10 classes in FashionMNIST
model=model.to(device)


In [39]:
#loss
criterion=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=0.001)


In [24]:
#Training loop
num_epochs=10
for epoch in range(num_epochs):
    model.train()
    total_loss=0
    for images, labels in train_loader:
        images,labels=images.to(device),labels.to(device)

        optimizer.zero_grad()
        outputs=model(images)
        loss=criterion(outputs,labels)
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()

    print(f"Epoch: {epoch}, Loss: {total_loss/len(train_loader)}")

Epoch: 0, Loss: 0.444078264698418
Epoch: 1, Loss: 0.30992521453640864
Epoch: 2, Loss: 0.2701185352639603
Epoch: 3, Loss: 0.24816758026764082
Epoch: 4, Loss: 0.23090379310251552
Epoch: 5, Loss: 0.20979964488080696
Epoch: 6, Loss: 0.20438229590892665
Epoch: 7, Loss: 0.18496168463198998
Epoch: 8, Loss: 0.17587674217326427
Epoch: 9, Loss: 0.1695203005091976


In [26]:
print(device)

mps


In [25]:
#Testing loop
model.eval()
correct=0
total=0

with torch.no_grad():
    for images, labels in test_loader:
        images,labels=images.to(device),labels.to(device)
        outputs=model(images)
        _,predicted=torch.max(outputs.data,1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()
accuracy=correct*100/total
print(f"Accuracy: {accuracy}")

Accuracy: 63.28


In [None]:
#Training loop with realtime plotting 
num_epochs=10
train_losses=[]
test_accuracy=[]
plt.ion()
fig, (ax1,ax2)=plt.subplots(1,2,figsize=(10,5))
for epoch in range(num_epochs):
    model.train()
    total_loss=0
    for images, labels in train_loader:
        images,labels=images.to(device),labels.to(device)

        optimizer.zero_grad()
        outputs=model(images)
        loss=criterion(outputs,labels)
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()

avg_loss=total_loss/len(train_loader)
train_losses.append(avg_loss)
print(f"Epoch: {epoch+1/{num_epochs}}, Loss: {avg_loss:}")

# Evaluate on test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
accuracy = correct * 100 / total
test_accuracy.append(accuracy)
print(f"Accuracy: {accuracy}")

# Plotting
ax1.clear()
ax1.plot(range(1, epoch + 2), train_losses, marker='o', linestyle='-', label='Train Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss per Epoch')
ax1.legend()

ax2.clear()
ax2.plot(range(1, epoch + 2),test_accuracy, marker='o', linestyle='-', color='red', label='Test Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Test Accuracy per Epoch')
ax2.legend()

plt.pause(0.1)
plt.ioff()
plt.show()

