In [1]:
from torchvision import transforms, utils
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
from transformers import AutoImageProcessor, AutoModelForImageClassification
from glob import glob
from PIL import Image
from tqdm.notebook import tqdm
import torch
import torch.optim as optim
import numpy as np
import os
import logging


from Data_Setup import setup_data_loaders, id2label, label2id

In [2]:
import os
import platform

# Root of the thesis data tree on each machine this notebook is run on.
# Set the THESIS_DATA_PATH environment variable to override it (e.g. on a new machine).
_DATA_PATHS_BY_SYSTEM = {
    'Darwin': "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync",
    'Linux': "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync",
}
DATA_PATH = os.environ.get("THESIS_DATA_PATH") or _DATA_PATHS_BY_SYSTEM.get(platform.system())
if DATA_PATH is None:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise RuntimeError(
        f"No data path configured for platform {platform.system()!r}; "
        "set THESIS_DATA_PATH or add an entry to _DATA_PATHS_BY_SYSTEM."
    )

In [3]:
# Train on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Initialize Model and Dataloaders

In [4]:
# Show the class-index -> label-name mapping imported from Data_Setup.
id2label

{0: 'casual_dresses',
 1: 'jersey_dresses',
 2: 'evening_dresses',
 3: 'knitted_dresses',
 4: 'maxi_dresses',
 5: 'shift_dresses',
 6: 'occasion_dresses',
 7: 'denim_dresses'}

In [5]:
# %%
# DINOv2-base checkpoint with a newly initialised classification head sized by our label maps.
model_name = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    id2label=id2label,
    label2id=label2id,
).to(device)

# Normalisation statistics and resampling filter published with the checkpoint's processor.
mean, std = processor.image_mean, processor.image_std
interpolation = processor.resample  # only referenced by the disabled RandomResizedCrop below

# NOTE(review): neither pipeline resizes/crops, so this presumably relies on Data_Setup
# yielding equally sized images for batching — confirm.

# Training pipeline: light augmentation (flip + colour jitter), then tensor + normalisation.
train_transform = Compose(
    [
        # RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=interpolation),
        RandomHorizontalFlip(p=0.5),
        ColorJitter(
            brightness=(0.6, 1.4),
            contrast=(0.6, 1.4),
            saturation=(0.6, 1.4),
        ),
        ToTensor(),
        Normalize(mean=mean, std=std),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
test_transform = Compose([ToTensor(), Normalize(mean=mean, std=std)])

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Training Loop

In [6]:
def train_epoch(model, train_loader, loss_fn, optimizer, report_interval=10, device=None):
    """Run one training epoch over ``train_loader`` and return the mean batch loss.

    Progress lines are printed and logged every ``report_interval`` batches. Their
    loss is averaged over the batches since the previous report. Their accuracy is
    the epoch-to-date accuracy, which is never reset within the epoch.

    Args:
        model: classifier whose forward output exposes ``.logits``.
        train_loader: sized iterable of dict batches with ``'image'`` and ``'label'`` tensors.
        loss_fn: callable ``(logits, labels) -> scalar loss tensor``.
        optimizer: optimizer stepping the model's trainable parameters.
        report_interval: number of batches between progress lines.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        Mean of the per-batch losses over the whole epoch, or ``nan`` if the loader is empty.
    """
    if device is None:
        # Follow the model: the notebook moves it to the global `device` before training.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    epoch_loss = 0.0    # sum of every batch loss this epoch (returned as a mean)
    running_loss = 0.0  # sum of batch losses since the last progress line
    running_total = 0   # epoch-to-date sample count for the reported accuracy
    running_correct = 0
    num_batches = 0

    for i, data in enumerate(tqdm(train_loader)):
        inputs = data['image'].to(device)
        labels = data['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs.logits, labels)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            _, predicted = torch.max(outputs.logits, 1)
            running_total += labels.size(0)
            running_correct += (predicted == labels).sum().item()
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1

        if i % report_interval == report_interval - 1:
            message = (f'\t[{i + 1}/{len(train_loader)}] loss: {running_loss / report_interval} '
                       f'accuracy: {np.round((running_correct/running_total)*100, 4)}%')
            print(message)
            logging.info(message)
            running_loss = 0.0

    # The windowed running_loss is reset at every report, so the epoch mean comes
    # from the separate epoch_loss accumulator.
    return epoch_loss / num_batches if num_batches else float('nan')

def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(test_loader):
            inputs, labels = data['image'], data['label']
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

    import torch
from tqdm import tqdm

def test_model(model, test_loader, device='cuda'):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    class_correct = {}
    class_total = {}

    with torch.no_grad():
        for data in tqdm(test_loader):
            inputs, labels = data['image'], data['label']
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.logits, 1)
            
            # Update overall accuracy counts
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Update per class accuracy counts
            for label, prediction in zip(labels, predicted):
                if label.item() not in class_correct:
                    class_correct[label.item()] = 0
                    class_total[label.item()] = 0
                class_correct[label.item()] += (prediction == label).item()
                class_total[label.item()] += 1

    # Calculate overall accuracy
    overall_accuracy = correct / total if total > 0 else 0

    # Calculate per class accuracy
    class_accuracies = {}
    for label in class_correct:
        class_accuracies[label] = (class_correct[label] / class_total[label]
                                   if class_total[label] > 0 else 0)

    return overall_accuracy, class_accuracies


In [7]:
# Define Hyperparameters
NUM_EPOCHS = 5
LR = 0.001
BATCH_SIZE = 8
NUM_WORKERS = 4

# When True, the DINOv2 backbone is frozen and only the classification head is trained.
BACKBONE_FROZEN = True

# Set requires_grad explicitly in both directions so that re-running this cell with a
# different BACKBONE_FROZEN value takes effect instead of keeping an earlier freeze.
for param in model.dinov2.parameters():
    param.requires_grad = not BACKBONE_FROZEN

model = model.to(device)

# Define Loss and optimizers.
# Frozen parameters receive no gradient and are skipped by the optimizer step.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

# One output directory per hyperparameter combination (checkpoints and log file go here).
save_dir = f"{DATA_PATH}/Models/Assessor/DinoV2/batch_size_{BATCH_SIZE}_LR_{LR}_NUM_EPOCHS_{NUM_EPOCHS}/"
# makedirs also creates missing parent directories; exist_ok makes re-running the cell safe.
os.makedirs(save_dir, exist_ok=True)

train_loader, test_loader = setup_data_loaders(train_transform, test_transform, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [8]:
# Send INFO-level logging to a fresh log file in save_dir. force=True closes and removes
# any root handlers left from a previous run of this cell; clearing logging.root.handlers
# by hand would drop them without closing them.
logging.basicConfig(filename=f'{save_dir}log.log', filemode='w', level=logging.INFO,
                    format='%(asctime)s - %(message)s', force=True)


def _report(message):
    """Echo a progress message to stdout and to the training log."""
    print(message)
    logging.info(message)


# Set True to record the untrained head's accuracy as a baseline before training.
EVALUATE_BEFORE_TRAINING = False

print('Starting Training...')
if EVALUATE_BEFORE_TRAINING:
    # test_model returns (overall_accuracy, per_class_accuracies).
    initial_accuracy, _ = test_model(model, test_loader)
    _report(f"Initial accuracy: {np.round(initial_accuracy*100, 4)}%")

for epoch in range(NUM_EPOCHS):
    _report(f"Epoch {epoch}:")
    train_epoch(model, train_loader, loss_fn, optimizer, report_interval=1)
    # Checkpoint after every epoch. torch.save on a module pickles the whole object, so
    # loading it later requires the same model classes to be importable.
    torch.save(model, f"{save_dir}Epoch_{epoch}")

    # Evaluate the accuracy after each epoch.
    accuracy, class_accuracy = test_model(model, test_loader)
    _report(f"Validation Accuracy after epoch {epoch}: {np.round(accuracy*100, 2)}%")
    class_accuracies = {id2label[k]: v for k, v in class_accuracy.items()}
    # Per-class accuracies, lowest first.
    logging.info(dict(sorted(class_accuracies.items(), key=lambda item: item[1])))

Starting Training...
Epoch 0:


  1%|          | 1/100 [00:04<06:59,  4.24s/it]

	[1/100] loss: 1.8247530460357666 accuracy: 37.5%


  2%|▏         | 2/100 [00:06<05:09,  3.16s/it]

	[2/100] loss: 1.8165723085403442 accuracy: 43.75%


  3%|▎         | 3/100 [00:09<04:32,  2.81s/it]

	[3/100] loss: 2.9165256023406982 accuracy: 37.5%


  4%|▍         | 4/100 [00:11<04:13,  2.65s/it]

	[4/100] loss: 0.9054660201072693 accuracy: 46.875%


  5%|▌         | 5/100 [00:13<04:02,  2.55s/it]

	[5/100] loss: 1.9569811820983887 accuracy: 47.5%


  6%|▌         | 6/100 [00:16<03:54,  2.50s/it]

	[6/100] loss: 3.456073760986328 accuracy: 39.5833%


  7%|▋         | 7/100 [00:18<03:49,  2.46s/it]

	[7/100] loss: 1.518676519393921 accuracy: 42.8571%


  8%|▊         | 8/100 [00:21<03:44,  2.44s/it]

	[8/100] loss: 1.5953742265701294 accuracy: 42.1875%


  9%|▉         | 9/100 [00:23<03:40,  2.43s/it]

	[9/100] loss: 2.1794373989105225 accuracy: 40.2778%


 10%|█         | 10/100 [00:25<03:37,  2.42s/it]

	[10/100] loss: 1.9936000108718872 accuracy: 40.0%


 11%|█         | 11/100 [00:28<03:34,  2.41s/it]

	[11/100] loss: 2.3001976013183594 accuracy: 36.3636%


 12%|█▏        | 12/100 [00:30<03:31,  2.41s/it]

	[12/100] loss: 4.134510040283203 accuracy: 33.3333%


 13%|█▎        | 13/100 [00:32<03:29,  2.40s/it]

	[13/100] loss: 1.417230486869812 accuracy: 34.6154%


 14%|█▍        | 14/100 [00:35<03:26,  2.40s/it]

	[14/100] loss: 1.1089272499084473 accuracy: 36.6071%


 15%|█▌        | 15/100 [00:37<03:23,  2.40s/it]

	[15/100] loss: 1.8612641096115112 accuracy: 37.5%


 16%|█▌        | 16/100 [00:40<03:21,  2.40s/it]

	[16/100] loss: 1.8556301593780518 accuracy: 35.9375%


 17%|█▋        | 17/100 [00:42<03:18,  2.40s/it]

	[17/100] loss: 1.9915390014648438 accuracy: 35.2941%


 18%|█▊        | 18/100 [00:44<03:16,  2.40s/it]

	[18/100] loss: 2.295224666595459 accuracy: 36.1111%


 19%|█▉        | 19/100 [00:47<03:14,  2.40s/it]

	[19/100] loss: 1.5759873390197754 accuracy: 36.1842%


 20%|██        | 20/100 [00:49<03:11,  2.40s/it]

	[20/100] loss: 2.015655517578125 accuracy: 36.25%


 21%|██        | 21/100 [00:52<03:09,  2.40s/it]

	[21/100] loss: 1.7931270599365234 accuracy: 36.3095%


 22%|██▏       | 22/100 [00:54<03:06,  2.40s/it]

	[22/100] loss: 2.5273492336273193 accuracy: 35.7955%


 23%|██▎       | 23/100 [00:56<03:04,  2.40s/it]

	[23/100] loss: 1.5743407011032104 accuracy: 36.9565%


 24%|██▍       | 24/100 [00:59<03:02,  2.40s/it]

	[24/100] loss: 1.325131893157959 accuracy: 37.5%


 25%|██▌       | 25/100 [01:01<02:59,  2.40s/it]

	[25/100] loss: 1.4674713611602783 accuracy: 37.0%


 26%|██▌       | 26/100 [01:04<02:57,  2.40s/it]

	[26/100] loss: 2.13875150680542 accuracy: 36.0577%


 27%|██▋       | 27/100 [01:06<02:54,  2.40s/it]

	[27/100] loss: 1.5030945539474487 accuracy: 37.037%


 28%|██▊       | 28/100 [01:08<02:52,  2.40s/it]

	[28/100] loss: 0.8544535636901855 accuracy: 37.9464%


 29%|██▉       | 29/100 [01:11<02:50,  2.40s/it]

	[29/100] loss: 1.9865882396697998 accuracy: 37.931%


 30%|███       | 30/100 [01:13<02:47,  2.39s/it]

	[30/100] loss: 1.1491241455078125 accuracy: 37.9167%


 31%|███       | 31/100 [01:16<02:45,  2.40s/it]

	[31/100] loss: 3.2845852375030518 accuracy: 37.0968%


 32%|███▏      | 32/100 [01:18<02:42,  2.40s/it]

	[32/100] loss: 1.079061508178711 accuracy: 37.8906%


 33%|███▎      | 33/100 [01:20<02:40,  2.40s/it]

	[33/100] loss: 1.0771944522857666 accuracy: 38.2576%


 34%|███▍      | 34/100 [01:23<02:38,  2.40s/it]

	[34/100] loss: 2.416074275970459 accuracy: 37.8676%


 35%|███▌      | 35/100 [01:25<02:35,  2.40s/it]

	[35/100] loss: 1.0023466348648071 accuracy: 38.9286%


 36%|███▌      | 36/100 [01:28<02:33,  2.40s/it]

	[36/100] loss: 1.4212942123413086 accuracy: 39.2361%


 37%|███▋      | 37/100 [01:30<02:30,  2.40s/it]

	[37/100] loss: 1.2923710346221924 accuracy: 39.527%


 38%|███▊      | 38/100 [01:32<02:28,  2.40s/it]

	[38/100] loss: 1.9053959846496582 accuracy: 38.8158%


 39%|███▉      | 39/100 [01:35<02:26,  2.39s/it]

	[39/100] loss: 2.453493595123291 accuracy: 38.7821%


 40%|████      | 40/100 [01:37<02:23,  2.40s/it]

	[40/100] loss: 1.7455047369003296 accuracy: 39.0625%


 40%|████      | 40/100 [01:40<02:30,  2.50s/it]


KeyboardInterrupt: 