<a href="https://colab.research.google.com/github/Vidhan-152/Fashion-MNIST-Classifier/blob/main/FineTuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
# Seed PyTorch's global RNG so weight init, dropout and DataLoader shuffling repeat across runs.
torch.manual_seed(seed=42)

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Load the training CSV; malformed rows are silently dropped rather than raising.
df = pd.read_csv(filepath_or_buffer='fashion-mnist_train.csv', on_bad_lines='skip')
# Preview the first five rows.
df.head(n=5)


In [None]:
# Display (rows, columns) of the loaded training CSV.
df.shape

In [None]:
# Preview the first 16 training images in a 4x4 grid, each titled with its label.
fig, axis = plt.subplots(4, 4, figsize=(10, 10))
fig.suptitle('First 16 images in the dataset')

for i, ax in enumerate(axis.flat):
    # Column 0 is the label; the remaining columns are the pixels of a 28x28 image.
    img = df.iloc[i, 1:].values.reshape(28, 28)
    # The pixels are single-channel intensities, so render them in grayscale
    # instead of matplotlib's default false-colour colormap.
    ax.imshow(img, cmap='gray')
    ax.axis('off')
    ax.set_title(df.iloc[i, 0])

plt.tight_layout()
plt.show()


In [None]:
# Column 0 holds the class label; every remaining column is a pixel value.
y = df.iloc[:, 0].to_numpy()
X = df.iloc[:, 1:].to_numpy()


In [None]:
# Inspect the raw pixel matrix (one flattened 28x28 image per row).
X

In [None]:
# Hold out 20% of the rows for evaluation; the fixed random_state keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Preprocessing pipeline for the pretrained VGG16 backbone.
from torchvision import transforms

# Steps, in order:
#   1. Upscale the 28x28 images so the shorter side is 256 px.
#   2. Keep the central 224x224 crop.
#   3. Convert to a float tensor.
#   4. Normalize each RGB channel with the ImageNet mean/std that torchvision's
#      pretrained models use.
custom_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
from PIL import Image
import numpy as np

class CustomDataset(Dataset):
    """Map-style dataset yielding (transformed RGB image, label) pairs.

    Args:
        features: Indexable collection of flattened images; each item must be
            reshapeable to (28, 28).
        labels: Indexable collection of integer class ids aligned with
            ``features``.
        transform: Callable applied to each PIL RGB image before it is
            returned (e.g. a torchvision ``Compose``).
    """

    def __init__(self, features, labels, transform):
        self.features = features
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # One sample per feature row.
        return len(self.features)

    def __getitem__(self, index):
        # Rebuild the 28x28 grid as 8-bit pixels.
        gray = np.uint8(self.features[index].reshape(28, 28))

        # Replicate the single grayscale channel into three identical RGB
        # channels -> shape (28, 28, 3).
        rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=-1)

        sample = self.transform(Image.fromarray(rgb))
        target = torch.tensor(self.labels[index], dtype=torch.long)
        return sample, target


In [None]:
# Wrap the train and test splits; both share the same preprocessing pipeline.
train_dataset = CustomDataset(features=X_train, labels=y_train, transform=custom_transform)
test_dataset = CustomDataset(features=X_test, labels=y_test, transform=custom_transform)

In [None]:
# Batch the datasets, shuffling only the training split.
# Pinned (page-locked) host memory only speeds up host-to-GPU copies. Enable it
# only when the model runs on CUDA; on a CPU-only runtime PyTorch warns that
# pinning is unused.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    pin_memory=device.type == 'cuda',
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    pin_memory=device.type == 'cuda',
)

In [None]:
# fetch the pretrained model
import torchvision.models as models

# `pretrained=True` has been deprecated since torchvision 0.13.
# IMAGENET1K_V1 is the checkpoint that flag loaded, so this fetches the same
# weights through the supported `weights=` API, without the deprecation warning.
vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

In [None]:
# Print the module tree of the pretrained network (before its classifier is replaced).
vgg16

In [None]:
# Freeze the pretrained convolutional backbone: none of its weights receive
# gradients, so only the classifier head is updated during fine-tuning.
for weight in vgg16.features.parameters():
    weight.requires_grad_(False)

In [None]:
# Replace the pretrained classifier head with a 10-way head:
#   25088 flattened conv features -> 1024 -> 512 -> 10 logits,
# with ReLU + Dropout(0.5) after each hidden layer.
vgg16.classifier = nn.Sequential(
    nn.Linear(in_features=25088, out_features=1024),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=1024, out_features=512),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=512, out_features=10),
)

In [None]:
# Move all parameters and buffers onto the selected device.
vgg16 = vgg16.to(device=device)

In [None]:
# Optimizer step size and number of full passes over the training set.
learning_rate, epochs = 1e-4, 12

In [None]:
# Cross-entropy on raw logits. Adam is given only the classifier head's
# parameters, so the frozen backbone is never updated.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=vgg16.classifier.parameters(), lr=learning_rate)

In [None]:
# Explicitly enter training mode. The evaluation cell calls vgg16.eval(), and
# re-running this cell afterwards would otherwise train with dropout disabled.
vgg16.train()

for epoch in range(epochs):
    running_loss = 0.0
    for batch_feature, batch_label in train_dataloader:
        batch_feature = batch_feature.to(device)
        batch_label = batch_label.to(device)

        # Forward pass
        outputs = vgg16(batch_feature)

        # Compute loss
        loss = criterion(outputs, batch_label)

        # Backpropagation and gradient update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch loss so training progress is visible.
    # max(..., 1) guards against an empty loader.
    mean_loss = running_loss / max(len(train_dataloader), 1)
    print(f"Epoch {epoch + 1}/{epochs} - loss: {mean_loss:.4f}")




In [None]:
# Evaluate on the held-out split. eval() disables dropout, and no_grad() skips
# building the autograd graph.
vgg16.eval()
total = 0
correct = 0

with torch.no_grad():
    for batch_features, batch_labels in test_dataloader:
        batch_features = batch_features.to(device)
        batch_labels = batch_labels.to(device)

        outputs = vgg16(batch_features)
        # Predicted class = index of the largest logit in each row
        # (replaces the discouraged `outputs.data` + torch.max unpacking).
        predicted = outputs.argmax(dim=1)

        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

# Fraction of test samples classified correctly; NaN instead of a
# ZeroDivisionError if the test loader yielded nothing.
accuracy = correct / total if total else float('nan')

In [None]:
# Test-set accuracy as a fraction of correctly classified samples (0 to 1).
accuracy

In [None]:
# Persist the full parameter/buffer dictionary (frozen backbone + trained head).
torch.save(obj=vgg16.state_dict(), f="model.pth")