<a href="https://colab.research.google.com/github/Aditib2409/PyTorch_ResNets_FashionMNIST/blob/main/FashionMNIST_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sklearn
import matplotlib.pyplot as plt
import itertools
from torch.utils.tensorboard import SummaryWriter


torch.set_printoptions(linewidth=120)

# ETL: extract FashionMNIST, transform each image with ToTensor, load in mini-batches.

data_set = torchvision.datasets.FashionMNIST(
	root = './data/FashionMNIST'
	,train=True
	,download=True
	,transform=transforms.Compose([
		transforms.ToTensor()
	])
)

# NOTE: `test_set` is held out from the official *training* split (train=True above),
# so it serves as a validation set, not the official FashionMNIST test split.
# The split is seeded so that it is identical across runs and kernel restarts.
split_seed = 42
val_size = 10000
train_set, test_set = torch.utils.data.random_split(
	data_set,
	[len(data_set) - val_size, val_size],
	generator=torch.Generator().manual_seed(split_seed),
)

batch_size = 100
# Shuffle only the training data; evaluation order does not affect the metrics.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)


class FashionMNISTResNet(nn.Module):
	"""ResNet-50 adapted for FashionMNIST classification.

	The RGB input stem (``conv1``) is rebuilt for ``in_channels`` inputs, and the
	1000-class head (``fc``) is replaced by a ``num_classes``-way linear layer.

	Args:
		in_channels: Number of input image channels (1 for grayscale FashionMNIST).
		num_classes: Number of output classes (10 for FashionMNIST).
		pretrained: Load torchvision's pretrained ResNet-50 weights, and initialise
			the new stem from the pretrained ``conv1`` filters.
	"""

	def __init__(self, in_channels=1, num_classes=10, pretrained=True):
		super(FashionMNISTResNet, self).__init__()
		self.model = torchvision.models.resnet50(pretrained=pretrained)

		# Rebuild the stem for `in_channels` inputs, keeping the original geometry.
		old_conv = self.model.conv1
		new_conv = nn.Conv2d(
			in_channels,
			old_conv.out_channels,
			kernel_size=old_conv.kernel_size,
			stride=old_conv.stride,
			padding=old_conv.padding,
			bias=False,
		)
		if pretrained:
			with torch.no_grad():
				# Average the pretrained RGB filters and rescale by 3 / in_channels. An
				# input whose channels all hold the same grey values then produces exactly
				# the response the pretrained stem gives to that grey image replicated
				# over R, G and B.
				mean_filter = old_conv.weight.mean(dim=1, keepdim=True)
				new_conv.weight.copy_(mean_filter.repeat(1, in_channels, 1, 1) * (3.0 / in_channels))
		self.model.conv1 = new_conv

		# Replace the 1000-class output layer with a num_classes-way classifier.
		num_ftrs = self.model.fc.in_features
		self.model.fc = nn.Linear(num_ftrs, num_classes)

	def forward(self, t):
		"""Return raw class scores (no softmax applied) for the image batch ``t``.

		``t`` is expected in the layout ``nn.Conv2d`` accepts: (N, in_channels, H, W).
		"""
		return self.model(t)
	

def get_num_correct_predictions(preds, labels):
	"""Count the rows of ``preds`` whose highest-scoring class equals ``labels``.

	Args:
		preds: 2-D tensor of per-class scores, one row per sample.
		labels: 1-D tensor of target class indices.

	Returns:
		The number of correct predictions as a Python int.
	"""
	predicted_classes = torch.argmax(preds, dim=1)
	is_correct = predicted_classes == labels
	return is_correct.sum().item()
# Learning-rate sweep: a fresh model is trained for each entry of lr_list.
lr_list = [0.001]
lr = 0.001
num_epochs = 9
acc = []      # training accuracy (%) for every epoch, across all learning rates
fin_acc = []  # final-epoch training accuracy (%) for each learning rate
ep = []       # epoch index matching each entry of `acc`

# Use a GPU when one is available (e.g. a Colab GPU runtime), else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Throwaway model and one sample batch, used only for the TensorBoard image/graph.
network = FashionMNISTResNet()
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

comment = f' batch_size={batch_size} lr={lr}'
tb = SummaryWriter(comment=comment)
tb.add_image('images', grid)
tb.add_graph(network, images)

for lr in lr_list:
	network = FashionMNISTResNet().to(device)
	optimizer = optim.Adam(network.parameters(), lr=lr)

	for epoch in range(num_epochs):
		ep.append(epoch)

		# ---- training pass ----
		network.train()
		total_loss = 0
		total_correct = 0
		for images, labels in train_loader:
			images, labels = images.to(device), labels.to(device)
			preds = network(images)
			loss = F.cross_entropy(preds, labels)
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			# Weight by the actual batch size so a short final batch is summed correctly.
			total_loss += loss.item() * images.size(0)
			total_correct += get_num_correct_predictions(preds, labels)

		# ---- evaluation pass ----
		# eval() makes BatchNorm use its running statistics instead of updating them
		# from held-out data; no_grad() skips building an autograd graph.
		network.eval()
		total_test_loss = 0
		total_test_correct = 0
		with torch.no_grad():
			for test_images, test_labels in test_loader:
				test_images, test_labels = test_images.to(device), test_labels.to(device)
				test_preds = network(test_images)
				total_test_loss += F.cross_entropy(test_preds, test_labels).item() * test_images.size(0)
				total_test_correct += get_num_correct_predictions(test_preds, test_labels)

		accuracy = total_correct / len(train_set) * 100
		acc.append(accuracy)
		test_accuracy = total_test_correct / len(test_set) * 100

		# Every TensorBoard value is keyed by `epoch`, so write exactly one point per
		# epoch. Histogramming all ResNet-50 parameters is also expensive, so it runs
		# here rather than inside the batch loop.
		tb.add_scalar('Loss', total_loss, epoch)
		tb.add_scalar('Number Correct', total_correct, epoch)
		tb.add_scalar('Accuracy', total_correct / len(train_set), epoch)
		tb.add_scalar('Test Loss', total_test_loss, epoch)
		tb.add_scalar('Number of Correct Test Cases', total_test_correct, epoch)
		tb.add_scalar('Test Accuracy', total_test_correct / len(test_set), epoch)
		for name, weight in network.named_parameters():
			tb.add_histogram(name, weight, epoch)
			# .grad holds the gradients from the last training batch of this epoch.
			tb.add_histogram(f'{name}.grad', weight.grad, epoch)

		print("epoch:", epoch, "total_correct:", total_correct, "loss:", total_loss, "Training Accuracy:", accuracy,'%', "Test Accuracy:", test_accuracy,'%')
	fin_acc.append(accuracy)
	print("Learning_rate:", lr, "Accuracy:", accuracy)

# Flush pending events to disk and release the log file.
tb.close()





Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

epoch: 0 total_correct: 41474 loss: 23894.590847194195 Training Accuracy: 82.948 % Test Accuracy: 87.48 %
epoch: 1 total_correct: 44316 loss: 15605.773413926363 Training Accuracy: 88.632 % Test Accuracy: 87.92 %
epoch: 2 total_correct: 45131 loss: 13515.896312147379 Training Accuracy: 90.262 % Test Accuracy: 89.98 %
epoch: 3 total_correct: 45545 loss: 12213.863567262888 Training Accuracy: 91.09 % Test Accuracy: 90.2 %
epoch: 4 total_correct: 45882 loss: 11342.022790014744 Training Accuracy: 91.764 % Test Accuracy: 90.49000000000001 %
epoch: 5 total_correct: 45402 loss: 12552.864323928952 Training Accuracy: 90.804 % Test Accuracy: 89.14 %
epoch: 6 total_correct: 45788 loss: 11340.712983161211 Training Accuracy: 91.57600000000001 % Test Accuracy: 90.86 %
epoch: 7 total_correct: 45584 loss: 12509.223701804876 Training Accuracy: 91.168 % Test Accuracy: 89.05999999999999 %
epoch: 8 total_correct: 45936 loss: 11029.030633717775 Training Accuracy: 91.872 % Test Accuracy: 90.94 %
