In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import decode_image
from torchvision import models
import os
import pandas as pd
import torchvision.transforms.v2 as T
from sklearn.metrics import accuracy_score

In [3]:
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, f"{self.img_labels.iloc[idx, 0]}.jpg")
        image = decode_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
    
def transform_target(label):
    brand_to_idx = {
        "adidas": 0,
        "converse": 1,
        "nike": 2,
    }
    return brand_to_idx[label]

transform = T.Compose([
    T.Resize(227),
    T.CenterCrop(227),
    T.ToDtype(torch.float32, scale=True),
])

training_data = CustomImageDataset(annotations_file="data/train/annotations.csv", img_dir="data/train/images", transform=transform,target_transform=transform_target)

testing_data = CustomImageDataset(annotations_file="data/test/annotations.csv", img_dir="data/test/images", transform=transform, target_transform=transform_target)

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(testing_data, batch_size=64, shuffle=False)

In [4]:
model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)

model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=3)

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /home/heathcliff/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


100%|██████████| 233M/233M [01:21<00:00, 2.99MB/s] 


In [5]:
for param in model.parameters():
    param.requires_grad_(False)

In [6]:
model

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [7]:
model.classifier.requires_grad_(True)

Sequential(
  (0): Dropout(p=0.5, inplace=False)
  (1): Linear(in_features=9216, out_features=4096, bias=True)
  (2): ReLU(inplace=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=4096, out_features=4096, bias=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=4096, out_features=3, bias=True)
)

In [8]:
for name, param in model.classifier.named_parameters():
    print(name, param.requires_grad)

1.weight True
1.bias True
4.weight True
4.bias True
6.weight True
6.bias True


In [9]:
if torch.cuda.is_available():
  model.to("cuda")
else:
  model.to("cpu")

In [11]:
from torch import nn

epochs = 20
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

In [12]:
def train_one_epoch():
    running_loss = 0.
    last_loss = 0.


    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optim.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)


        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optim.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    return last_loss

In [13]:
for epoch in range(epochs):
  print('EPOCH {}:'.format(epoch + 1))

  # Make sure gradient tracking is on, and do a pass over the data
  model.train(True)
  avg_loss = train_one_epoch()


  running_vloss = 0.0
  # Set the model to evaluation mode, disabling dropout and using population
  # statistics for batch normalization.
  model.eval()

  preds = []
  true = []

  # Disable gradient computation and reduce memory consumption.
  with torch.no_grad():
      for i, vdata in enumerate(test_dataloader):
          vinputs, vlabels = vdata
          voutputs = model(vinputs)
          preds = preds + list(torch.max(voutputs, 1).indices.cpu().numpy())
          true = true + list(vlabels.cpu().numpy())
          vloss = loss_fn(voutputs, vlabels)
          running_vloss += vloss.item()
  acc = accuracy_score(true, preds)
  avg_vloss = running_vloss / (i + 1)
  print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
  print('ACCURACY valid {}'.format(acc))

EPOCH 1:
LOSS train 0.0 valid 0.8818612098693848
ACCURACY valid 0.6371681415929203
EPOCH 2:
LOSS train 0.0 valid 0.6990245878696442
ACCURACY valid 0.7610619469026548
EPOCH 3:
LOSS train 0.0 valid 1.0554738119244576
ACCURACY valid 0.6194690265486725
EPOCH 4:
LOSS train 0.0 valid 0.7677903324365616
ACCURACY valid 0.7256637168141593
EPOCH 5:
LOSS train 0.0 valid 0.6266226321458817
ACCURACY valid 0.7787610619469026
EPOCH 6:
LOSS train 0.0 valid 0.7162491828203201
ACCURACY valid 0.7522123893805309
EPOCH 7:
LOSS train 0.0 valid 0.6060633957386017
ACCURACY valid 0.8230088495575221
EPOCH 8:
LOSS train 0.0 valid 0.7711837887763977
ACCURACY valid 0.8053097345132744
EPOCH 9:
LOSS train 0.0 valid 1.0648573189973831
ACCURACY valid 0.7522123893805309
EPOCH 10:
LOSS train 0.0 valid 0.777723491191864
ACCURACY valid 0.8141592920353983
EPOCH 11:
LOSS train 0.0 valid 0.990197092294693
ACCURACY valid 0.7699115044247787
EPOCH 12:
LOSS train 0.0 valid 0.8498388528823853
ACCURACY valid 0.7876106194690266
EPO

In [14]:
torch.save(model, 'model.pth')