# Imports

In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# import torchvision.io as io
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision import models, transforms
from tqdm.auto import tqdm, trange

import wandb
from helpers.dataset import HairStyleDataset

%load_ext autoreload
%autoreload 2

# Load data

In [2]:
# Dataset root. When running on Colab this was
# "/content/drive/MyDrive/UCU/CV/HW5_Project/data/"; locally it is "data".
data_path = "data"
ann_path = os.path.join(data_path, "validation_annotation.csv")
annotation_df = pd.read_csv(ann_path)

In [3]:
# Hold out 20% of the annotations for validation. A fixed random_state keeps the
# split identical across kernel restarts so runs are comparable.
train_df, val_df = train_test_split(annotation_df, test_size=0.2, random_state=42)
train_df, val_df = train_df.reset_index(drop=True), val_df.reset_index(drop=True)

In [4]:
# Sanity check: are any hairstyle labels missing? pd.read_csv stores empty
# cells as NaN rather than None, so an `is None` test would never fire;
# isna() detects both NaN and None.
annotation_df["basestyle"].isna().any()

False

In [5]:
# Encode hairstyle names as integer class ids. The train labels followed by the
# validation labels are factorized in a single call so both splits share one
# mapping; `label_encoding[i]` recovers the name for class id i.
all_labels = np.concatenate(
    (train_df["basestyle"].values, val_df["basestyle"].values)
)
img_labels, label_encoding = pd.factorize(all_labels)

# The first len(train_df) codes belong to the training rows, the rest to validation.
train_df["basestyle"] = img_labels[: len(train_df)]
val_df["basestyle"] = img_labels[len(train_df) :]

In [6]:
def decode_label(label, label_encoding):
    """Translate an integer class id back into its hairstyle name.

    Args:
        label: Class id produced by ``pd.factorize``.
        label_encoding: The ``uniques`` object returned by ``pd.factorize``.

    Returns:
        The element of ``label_encoding`` at position ``label``.
    """
    name = label_encoding[label]
    return name

# Model init

In [7]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when doing feature extraction.

    Args:
        model: Module whose ``parameters()`` are updated.
        feature_extracting: When truthy, every parameter gets
            ``requires_grad = False``; when falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [8]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a torchvision classifier with a fresh head for ``num_classes``.

    Args:
        model_name: One of "resnet", "alexnet", "vgg", "squeezenet",
            "densenet", "inception".
        num_classes: Number of outputs of the replacement classification head.
        feature_extract: If True, freeze every backbone parameter (via
            ``set_parameter_requires_grad``). The replacement head is created
            after freezing, so it stays trainable.
        use_pretrained: If True, load torchvision's DEFAULT weights; otherwise
            the backbone is randomly initialised (``weights=None``).

    Returns:
        ``(model, input_size)`` where ``input_size`` is the side length the
        image transforms should produce (224, or 299 for Inception v3).

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    if model_name == "resnet":
        # ResNet-18: replace the final fully connected layer.
        weights = models.ResNet18_Weights.DEFAULT if use_pretrained else None
        model_ft = models.resnet18(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet: replace the last layer of the classifier (classifier[6]).
        weights = models.AlexNet_Weights.DEFAULT if use_pretrained else None
        model_ft = models.alexnet(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG-11 with batch norm: replace classifier[6].
        weights = models.VGG11_BN_Weights.DEFAULT if use_pretrained else None
        model_ft = models.vgg11_bn(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0 classifies with a 1x1 convolution (classifier[1])
        # instead of a Linear layer.
        weights = models.SqueezeNet1_0_Weights.DEFAULT if use_pretrained else None
        model_ft = models.squeezenet1_0(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        in_channels = model_ft.classifier[1].in_channels
        model_ft.classifier[1] = nn.Conv2d(
            in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1)
        )
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121: the classifier is a single Linear layer.
        weights = models.DenseNet121_Weights.DEFAULT if use_pretrained else None
        model_ft = models.densenet121(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier = nn.Linear(model_ft.classifier.in_features, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3 expects 299x299 inputs and has an auxiliary classifier,
        # so both the auxiliary and the primary heads are replaced.
        weights = models.Inception_V3_Weights.DEFAULT if use_pretrained else None
        model_ft = models.inception_v3(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        aux_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(aux_ftrs, num_classes)
        model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
        input_size = 299

    else:
        # Raise instead of exit(): exit() would shut down the notebook kernel.
        raise ValueError(
            f"Invalid model name {model_name!r}; expected one of 'resnet', "
            "'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'."
        )

    return model_ft, input_size

# Train

In [9]:
def train(model, dataloader, criterion, optimizer, device=None):
    """Run one training epoch and log metrics to Weights & Biases.

    Args:
        model: Network to optimise; switched to train mode here.
        dataloader: Yields ``(inputs, labels)`` batches; its ``dataset`` length
            is used to average the metrics.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimiser over the trainable parameters of ``model``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(epoch_loss, epoch_acc)`` as floats: the sample-weighted mean loss and
        the accuracy over the dataset, or NaN for both if the dataset is empty.
    """
    if device is None:
        device = next(model.parameters(), torch.empty(0)).device
    model.train()
    running_loss = 0.0
    running_corrects = 0
    for batch_idx, (inputs, labels) in enumerate(tqdm(dataloader, leave=False)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        if isinstance(outputs, tuple):
            # Inception in train mode also returns auxiliary logits; the
            # auxiliary loss is added with weight 0.4.
            outputs, aux_outputs = outputs
            loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
        else:
            loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Plain Python float: no reference to the autograd graph is kept.
        batch_loss = loss.item()
        preds = outputs.argmax(dim=1)
        running_loss += batch_loss * inputs.size(0)
        running_corrects += (preds == labels).sum().item()
        if batch_idx % 10 == 0:
            wandb.log({"Training Running Loss": batch_loss})

    n_samples = len(dataloader.dataset)
    epoch_loss = running_loss / n_samples if n_samples else float("nan")
    epoch_acc = running_corrects / n_samples if n_samples else float("nan")
    wandb.log({"Training Loss": epoch_loss, "Training Accuracy": epoch_acc})
    return epoch_loss, epoch_acc

In [10]:
def test(model, dataloader, criterion):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    for batch_idx, (inputs, labels) in enumerate(tqdm(dataloader, leave=False)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.cpu().double() / len(dataloader.dataset)
    wandb.log({"Validation Loss": epoch_loss, "Validation Accuracy": epoch_acc})

In [11]:
num_classes = label_encoding.shape[0]

# Prefer CUDA (e.g. Colab), then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Seed head initialisation, augmentation and shuffling for reproducible runs.
torch.manual_seed(42)

criterion = nn.CrossEntropyLoss()

# Names supported by initialize_model:
#   ["resnet", "alexnet", "vgg", "squeezenet", "densenet", "inception"]
model_names = ["resnet", "alexnet"]

# Training hyperparameters (also logged to W&B below).
batch_size = 64
num_epochs = 25
lr, momentum = 0.001, 0.9
step_size, gamma = 7, 0.1  # multiply the LR by 0.1 every 7 epochs
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

for model_name in tqdm(model_names):
    model, input_size = initialize_model(model_name, num_classes, feature_extract=True)
    model.to(device)

    transform = {
        "train": transforms.Compose(
            [
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(imagenet_mean, imagenet_std),
            ]
        ),
        "val": transforms.Compose(
            [
                transforms.Resize(input_size),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                transforms.Normalize(imagenet_mean, imagenet_std),
            ]
        ),
    }

    train_ds = HairStyleDataset(train_df, transform["train"])
    val_ds = HairStyleDataset(val_df, transform["val"])
    train_dl = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=4
    )
    # Validation metrics do not depend on order, so keep evaluation deterministic.
    val_dl = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=4
    )

    wandb.init(
        project="hairstyle",
        name=model_name,
        config={
            "batch_size": batch_size,
            "epochs": num_epochs,
            "lr": lr,
            "momentum": momentum,
            "step_size": step_size,
            "gamma": gamma,
        },
    )
    wandb.watch(model)

    # With feature extraction only the new head has requires_grad=True.
    params_to_update = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params_to_update, lr=lr, momentum=momentum)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    for epoch in trange(num_epochs, leave=False):
        train(model, train_dl, criterion, optimizer)
        scheduler.step()
        test(model, val_dl, criterion)

    torch.save(model.state_dict(), f"{model_name}_weights.pth")
    # Close the run explicitly so each model gets its own W&B run and the
    # last run is flushed when the loop ends.
    wandb.finish()

  0%|          | 0/2 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33mastadnik[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  tensor = flat.histc(bins=self._num_bins, min=tmin, max=tmax)


  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [02:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [02:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [02:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.


0,1
Training Accuracy,▁▅▆▆▇▇▇██████████████████
Training Loss,█▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Training Running Loss,█▆▅▄▂▄▅▃▃▁▃▃▄▁▃▂▂▂▃▂▂▄▂▂▃▁▂▁▃▃▃▄▂▄▃▂▄▃▃▃
Validation Accuracy,▁▃▅▆▇▇▇██████████████████
Validation Loss,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Training Accuracy,0.52664
Training Loss,1.54281
Training Running Loss,1.62545
Validation Accuracy,0.57103
Validation Loss,1.4009


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01683465833387648, max=1.0)…

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [02:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [02:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [02:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:04<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [00:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]

  0%|          | 0/604 [02:02<?, ?it/s]

  0%|          | 0/151 [00:02<?, ?it/s]