In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random 
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pickle
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.models import resnet50, ResNet50_Weights
from tqdm.notebook import tqdm

from helpers_training import *

Using device: mps


### General Setup

In [2]:
import os
import platform

# Root of the thesis data directory, resolved per host so the same notebook runs
# locally (macOS) and on the Linux cluster. A non-empty THESIS_DATA_PATH environment
# variable overrides the built-in defaults (e.g. for a new machine).
DEFAULT_DATA_PATHS = {
    'Darwin': "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync",
    'Linux': "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync",
}
DATA_PATH = os.environ.get("THESIS_DATA_PATH") or DEFAULT_DATA_PATHS.get(platform.system())
if DATA_PATH is None:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise RuntimeError(
        f"No default data path for platform {platform.system()!r}; "
        "set the THESIS_DATA_PATH environment variable."
    )

In [3]:
# Seed first, before anything stochastic runs. set_seed/set_device come from
# helpers_training (presumably seeding python/numpy/torch and picking mps/cuda/cpu).
SEED = 42
set_seed(SEED)
device = set_device()

Using device: mps


### Data Setup

In [4]:
# Metadata column used as the classification target.
# NOTE(review): `retrain` is not read by any cell visible here; the training cell below
# runs unconditionally.
target_feature, retrain = 'category', False

In [5]:
metadata_path = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/metadata/dresses_metadata.json"
# After transposing, each former top-level JSON key becomes a row; that index is
# exposed as the `sku` column.
raw_metadata = pd.read_json(metadata_path)
df = raw_metadata.T.reset_index().rename(columns={'index': 'sku'})
#df = df.sample(4000)
# prepare_data (helpers_training) returns the prepared frame plus label <-> id mappings.
df, id2label, label2id = prepare_data(df, target_feature)

In [6]:
# ImageNet channel statistics expected by the pretrained ResNet50 weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize plus light augmentation (flip, rotation, colour jitter).
train_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation / inference pipeline: deterministic (no augmentation), but it MUST use the
# same normalization as training. Otherwise the model sees inputs from a different
# distribution than the one it was trained on.
test_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

image_root = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/square_images/"
# get_datasets (helpers_training) builds the two splits; each receives its own
# transform pipeline (presumably augmenting for train, deterministic for test).
train_dataset, test_dataset = get_datasets(df, image_root, train_transform, test_transform)

### Model Setup

In [7]:
def setup_model(num_labels, frozen_backbone=True):
    """Build an ImageNet-pretrained ResNet50 with a fresh classification head.

    Args:
        num_labels: Number of output classes of the new final ``fc`` layer.
        frozen_backbone: If True, every pretrained parameter is frozen, so only the
            newly created ``fc`` layer is trainable.

    Returns:
        The model, moved to the notebook-level ``device``.
    """
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    if frozen_backbone:
        for param in model.parameters():
            param.requires_grad = False

    # Replace the 1000-way ImageNet head with one sized by the caller's `num_labels`.
    # The previous version silently overwrote `num_labels` with len(id2label).
    # The new layer is created after freezing, so its parameters stay trainable.
    model.fc = nn.Linear(model.fc.in_features, num_labels)
    return model.to(device)

In [8]:
# Hyperparameters
BATCH_SIZE = 16
LR = 0.001
EPOCHS = 3

# Linear probe: the pretrained backbone is frozen and only the new head is trained.
model = setup_model(len(id2label), frozen_backbone=True)

# Loss and optimizer. All parameters are handed to Adam, but the frozen ones never
# receive gradients and are skipped by the optimizer step.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

save_dir = f"{DATA_PATH}/Models/Assessor/ResNet/"
# NOTE(review): this cell always trains; the `retrain` flag is not consulted.
train_model(
    model, train_dataset, test_dataset, criterion, optimizer, EPOCHS, BATCH_SIZE,
    report_interval=10,
    eval_every=1000000,  # presumably large enough to skip mid-epoch evaluation — TODO confirm
    save_dir=save_dir,
)

Epoch 1:   0%|          | 0/703 [00:00<?, ?it/s]

Batch 10: LOSS within report interval: 1.9085091829299927 | ACCURACY within report interval: 0.325
Batch 20: LOSS within report interval: 1.789617669582367 | ACCURACY within report interval: 0.365625
Batch 30: LOSS within report interval: 1.71024329662323 | ACCURACY within report interval: 0.36875
Batch 40: LOSS within report interval: 1.617199444770813 | ACCURACY within report interval: 0.375
Batch 50: LOSS within report interval: 1.695075273513794 | ACCURACY within report interval: 0.385
Batch 60: LOSS within report interval: 1.723379945755005 | ACCURACY within report interval: 0.38125
Batch 70: LOSS within report interval: 1.7763778448104859 | ACCURACY within report interval: 0.37857142857142856
Batch 80: LOSS within report interval: 1.6570648789405822 | ACCURACY within report interval: 0.38828125
Batch 90: LOSS within report interval: 1.6552448511123656 | ACCURACY within report interval: 0.3875
Batch 100: LOSS within report interval: 1.6525492429733277 | ACCURACY within report inte

Epoch 2:   0%|          | 0/703 [00:00<?, ?it/s]

Batch 10: LOSS within report interval: 1.2888161897659303 | ACCURACY within report interval: 0.50625
Batch 20: LOSS within report interval: 1.3641527533531188 | ACCURACY within report interval: 0.525
Batch 30: LOSS within report interval: 1.3695538759231567 | ACCURACY within report interval: 0.5041666666666667
Batch 40: LOSS within report interval: 1.4748401999473573 | ACCURACY within report interval: 0.4921875
Batch 50: LOSS within report interval: 1.3661333084106446 | ACCURACY within report interval: 0.5075
Batch 60: LOSS within report interval: 1.3317946791648865 | ACCURACY within report interval: 0.509375
Batch 70: LOSS within report interval: 1.397778046131134 | ACCURACY within report interval: 0.5098214285714285
Batch 80: LOSS within report interval: 1.4318916916847229 | ACCURACY within report interval: 0.50390625
Batch 90: LOSS within report interval: 1.4020453572273255 | ACCURACY within report interval: 0.5
Batch 100: LOSS within report interval: 1.2933443069458008 | ACCURACY w

Epoch 3:   0%|          | 0/703 [00:00<?, ?it/s]

Batch 10: LOSS within report interval: 1.412113881111145 | ACCURACY within report interval: 0.5
Batch 20: LOSS within report interval: 1.4006840705871582 | ACCURACY within report interval: 0.46875
Batch 30: LOSS within report interval: 1.3991490602493286 | ACCURACY within report interval: 0.47708333333333336
Batch 40: LOSS within report interval: 1.2616963028907775 | ACCURACY within report interval: 0.5
Batch 50: LOSS within report interval: 1.3619627714157105 | ACCURACY within report interval: 0.50125
Batch 60: LOSS within report interval: 1.436678123474121 | ACCURACY within report interval: 0.49895833333333334
Batch 70: LOSS within report interval: 1.4003711223602295 | ACCURACY within report interval: 0.4901785714285714
Batch 80: LOSS within report interval: 1.3139843463897705 | ACCURACY within report interval: 0.4984375
Batch 90: LOSS within report interval: 1.517417049407959 | ACCURACY within report interval: 0.4888888888888889
Batch 100: LOSS within report interval: 1.461210930347

### Evaluate Model

In [9]:
# Reload the final checkpoint written by the training cell (epoch 3, matching EPOCHS = 3).
save_path = f"{save_dir}model_epoch_3.pt"
# The checkpoint is a fully pickled nn.Module, not a state_dict. It therefore needs
# weights_only=False on PyTorch >= 2.6, where the default became True. Unpickling can
# execute arbitrary code, so only load checkpoints produced by this project.
# map_location lets a checkpoint saved on one backend (e.g. mps) load on another.
model = torch.load(save_path, map_location=device, weights_only=False)
model = model.to(device)
# Inference mode: BatchNorm uses its running statistics instead of per-batch ones and
# does not update them.
model.eval()

In [10]:
from torch.nn.functional import softmax

# Predict a label and its softmax confidence for every row of `df`, in batches, using
# the evaluation transform. Results are collected in lists and written to `df` once at
# the end. This avoids a slow per-row `df.loc` write and never leaves half-filled
# prediction columns if the cell is interrupted.
INFERENCE_BATCH_SIZE = 64
square_images_dir = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/square_images"

# Set eval mode here too, so this cell is correct even if run without the load cell.
model.eval()
predicted_labels, predicted_probs = [], []
skus = df['sku'].tolist()
with torch.no_grad():
    for start in tqdm(range(0, len(skus), INFERENCE_BATCH_SIZE), desc="predicting"):
        batch_images = []
        for sku in skus[start:start + INFERENCE_BATCH_SIZE]:
            # `with` closes the file handle; convert('RGB') guarantees 3 channels.
            with Image.open(f"{square_images_dir}/{sku}.jpg") as img:
                batch_images.append(test_transform(img.convert('RGB')))
        logits = model(torch.stack(batch_images).to(device))
        prob, predicted = torch.max(softmax(logits, dim=1), dim=1)
        predicted_labels.extend(predicted.cpu().tolist())
        predicted_probs.extend(prob.cpu().tolist())

df['predicted_label'] = predicted_labels
df['predicted_prob'] = predicted_probs
df['predicted_category'] = df['predicted_label'].map(id2label)

  0%|          | 0/14060 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Accuracy per split, evaluated train first and then test, with unshuffled loaders.
# NOTE(review): if get_datasets applies `train_transform` to the train split, this
# train accuracy is measured on randomly augmented images.
split_accuracy = {}
for split_name, split_ds in [('train', train_dataset), ('test', test_dataset)]:
    split_accuracy[split_name] = evaluate(model, DataLoader(split_ds, batch_size=32, shuffle=False))
train_acc = split_accuracy['train']
test_acc = split_accuracy['test']
print(f"Train Accuracy: {np.round(train_acc*100, 2)}%, Test Accuracy: {np.round(test_acc*100, 2)}%")

In [None]:
from sklearn.metrics import accuracy_score

# Accuracy over every row of `df`. NOTE(review): `df` presumably contains both splits,
# so this mixes in-sample and held-out performance.
y_true, y_pred = df['label'], df['predicted_label']
overall_accuracy = accuracy_score(y_true, y_pred)
print(f"Overall Accuracy: {overall_accuracy:.2%}")

In [None]:
# Per-category accuracy: the share of rows in each `target_feature` group whose
# predicted label equals the true label (the same quantity accuracy_score gives per
# group). Computed vectorised rather than with a helper named `group_accuracy`: that
# helper was shadowed by the result Series, so re-running the cell failed.
is_correct = df['label'] == df['predicted_label']
group_accuracy = (
    is_correct.groupby(df[target_feature])
    .mean()
    .sort_values(ascending=False)
)
# Keep the Series unnamed, as groupby().apply produced it. Code that wraps it in a
# DataFrame then finds the values in column 0.
group_accuracy.name = None

# Print the accuracy for each category
print("Accuracy by Group:")
print(group_accuracy)

In [None]:
import seaborn as sns

# One point per category: accuracy against number of samples, labelled with the
# category name, showing whether accuracy tracks category frequency.
# value_counts().reset_index() yields columns [target_feature, 'count'] (pandas >= 2.0).
category_counts = df[target_feature].value_counts().reset_index()
plot_data = (
    pd.DataFrame(group_accuracy)
    .reset_index()
    .rename(columns={0: 'accuracy'})
    .merge(category_counts)
)

fig, ax = plt.subplots()
sns.scatterplot(data=plot_data, x='accuracy', y='count', ax=ax)
for acc, n_samples, name in zip(plot_data['accuracy'], plot_data['count'], plot_data[target_feature]):
    ax.text(acc, n_samples, name, fontsize=9)
ax.set(
    title=f"Accuracy vs. number of samples per {target_feature}",
    xlabel="Accuracy",
    ylabel="Number of samples",
);

In [None]:
# Confusion matrix as a count table: rows are predicted categories, columns actual ones.
confusion_matrix = pd.crosstab(
    index=df['predicted_category'],
    columns=df[target_feature],
    rownames=['Predicted'],
    colnames=['Actual'],
)
confusion_matrix

In [None]:
# Let's look at some randomly chosen misclassified examples (true vs. predicted category).
N_MISCLASSIFIED_EXAMPLES = 1
root_path = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/square_images/"
misclassified = df[df['label'] != df['predicted_label']]
# Cap the sample size so the cell does not raise when few (or no) rows are misclassified.
misclassified = misclassified.sample(min(N_MISCLASSIFIED_EXAMPLES, len(misclassified)))
for i, row in misclassified.iterrows():
    img_path = f"{root_path}{row['sku']}.jpg"  # root_path already ends with '/'
    fig, ax = plt.subplots()
    # Convert to an array while the file is open; `with` then closes the handle.
    with Image.open(img_path) as img:
        ax.imshow(np.asarray(img))
    ax.set_title(f"Actual: {row[target_feature]}\nPredicted: {row['predicted_category']}")
    ax.axis('off')

### Weights and Biases

In [None]:
import wandb

In [None]:
# def train_epoch(model, train_loader, test_loader, criterion, optimizer, epoch_num, report_interval=2):
#     model.train()
#     running_loss = 0.0
#     running_corrects = 0
#     running_total = 0

#     # Configure how often to evaluate based on the total images shown
#     eval_steps = max(1, 1000 // train_loader.batch_size)  

#     for i, data in enumerate(train_loader):
#         images, labels = data
#         images, labels = images.to(device), labels.to(device)
#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss.item()
#         _, predicted = torch.max(outputs, 1)
#         running_corrects += (predicted == labels).sum().item()
#         running_total += labels.size(0)

#         if i % eval_steps == eval_steps-1:
#             test_acc = evaluate(model, test_loader)
#             wandb.log({'test_accuracy': test_acc})
#             print(f"Batch {i+1}: Running LOSS: {running_loss / eval_steps} | Running ACCURACY: {running_corrects / running_total} | TEST ACCURACY: {test_acc}")
#             running_loss = 0.0
#             running_corrects = 0
#             running_total = 0

#     train_acc = evaluate(model, train_loader)
#     wandb.log({'train_accuracy': train_acc, 'final_test_accuracy': evaluate(model, test_loader)})

# def evaluate(model, data_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for data in data_loader:
#             images, labels = data
#             images, labels = images.to(device), labels.to(device)
#             outputs = model(images)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#     return correct / total

# def train_model(config=None):
#     with wandb.init(config=config):
#         config = wandb.config

#         # Define your model, criterion, and optimizer here
#         model = setup_model(len(id2label))
#         criterion = nn.CrossEntropyLoss()
#         optimizer = get_optimizer(model, config.optimizer, config.learning_rate)
        
#         train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
#         test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

#         for epoch in range(config.epochs):
#             train_epoch(model, train_loader, test_loader, criterion, optimizer, epoch + 1)

# def get_optimizer(model, optimizer_name, learning_rate):
#     if optimizer_name == 'adam':
#         return torch.optim.Adam(model.parameters(), lr=learning_rate)
#     elif optimizer_name == 'sgd':
#         return torch.optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# wandb.init(project='model_finetuning')

# # Sweep configuration
# sweep_config = {
#     'method': 'random',
#     'metric': {
#         'name': 'final_test_accuracy',
#         'goal': 'maximize'
#     },
#     'parameters': {
#         'batch_size': {
#             'values': [8, 16, 32, 64]
#         },
#         'learning_rate': {
#             'values': [0.001, 0.0001, 0.01]
#         },
#         'optimizer': {
#             'values': ['adam', 'sgd']
#         },
#         'epochs': {
#             'value':  1 # Adjust number of epochs for each run
#         }
#     }
# }

# sweep_id = wandb.sweep(sweep_config, project='resnet_category_tuning')
# wandb.agent(sweep_id, train_model)
