# Setup
## Constants

In [None]:
!pip install transformers -q
!pip install torch -q
!pip install torchvision -q

In [None]:
# Set to False when a .pth file has already been created and saved under PTH_SAVE_PATH.
CREATE_NEW_DATASET = True

# Fraction of the data assigned to each split; the three values should sum to 1.
SET_SIZES = dict(train=0.8, test=0.1, val=0.1)

# Number of samples drawn per class when building the uniform dataset.
N_SAMPLES = 400

# Directory holding the dataset files (do not change).
HM_DATA_PATH = "../dataset/"

# Directory holding the saved .pth files (do not change).
PTH_SAVE_PATH = "../pth/"

## Imports

In [None]:
import os, sys, random, importlib, transformers, itertools, copy
import numpy as np, torch.nn as nn, torch, seaborn as sns, matplotlib.pyplot as plt, pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader
print(os.getcwd())  # show the directory the notebook is running from
# Our own files: project-local modules, imported from the default import path
# because the sys.path tweak below is disabled.
# sys.path.append('./src/')
import model_functions, utils, training, datasets
def set_seed(seed):
    """Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (including the CUDA generators when a GPU is available), then switches
    cuDNN to deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # cuDNN flags are set unconditionally, exactly as before.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(0)

In [None]:
def update():
    """Reload the project-local modules so edits to their source take effect.

    The kernel otherwise keeps using the versions imported earlier, which was
    causing stale code to run after a file changed.
    """
    import model_functions, utils, training, datasets

    project_modules = (model_functions, utils, training, datasets)
    for module in project_modules:
        importlib.reload(module)


update()

In [None]:
# Pick the compute device: CUDA first, then Apple-silicon MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print("Using device:", device)

In [None]:
# Load the pretrained CLIP weights once and move the model to the selected device.
# run_lora() deep-copies this model, so the checkpoint is not reloaded per experiment.
model = transformers.CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = transformers.CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Use `image_processor`: `feature_extractor` is a deprecated alias for the same object
# and is scheduled for removal from transformers.
# do_rescale=False -> the processor does NOT multiply pixel values by 1/255, so images
#                     must already be in [0, 1];
# do_rescale=True  -> the processor expects [0, 255] input and rescales it itself.
processor.image_processor.do_rescale = False

# Dataset

### Full dataset, run once

In [None]:
# Metadata column name handed to hmd.article_id2suclass and datasets.get_dataloaders
# in the cells below; presumably the label column for the experiments — confirm in `datasets`.
column_name = 'garment_group_name'

In [None]:
# Article metadata table.
df = pd.read_csv(HM_DATA_PATH + 'articles_filtered.csv')

# Saved embeddings and labels (the original notes describe ~100k of each),
# loaded in that order; labels are kept as a plain Python list.
embs, labs = (
    torch.load(HM_DATA_PATH + pth_name, weights_only=True)
    for pth_name in ('embedds.pth', 'labels.pth')
)
labs = labs.tolist()

hmd = datasets.HMDatasetDuplicates(embs, np.array(labs), df)

# Quick sanity output: the looked-up value for one article id, and the label count.
print(hmd.article_id2suclass(694805002, column_name))
print(len(labs))

BALANCED = False

In [None]:
update()  # reload the project modules in case their source changed
# NOTE(review): this lowercase dict repeats the train/val fractions of SET_SIZES from the
# constants cell and has no "test" entry — presumably the remainder becomes the test
# split; confirm in datasets.datasets.
set_sizes = {"train": 0.8, "val": 0.1}
# NOTE(review): the meaning of the positional `True` is not visible here — confirm
# against the signature of datasets.datasets.
data = datasets.datasets(embs, np.array(labs), df, set_sizes, True)# takes 3 min

### Subsets

In [None]:
# Class names passed to datasets.get_dataloaders as the classes to exclude.
exclude_classes = ["Unknown", "Special Offers", "Woven/Jersey/Knitted mix Baby"]

In [None]:
update()  # reload the project modules in case their source changed
# NOTE(review): 5000 and 324 are positional magic numbers. The saved-model name
# "lora-cap-5000-..." used later suggests 5000 is a per-class sample cap; the meaning
# of 324 is not visible here — confirm both against datasets.get_dataloaders.
dataloaders_imbalanced = datasets.get_dataloaders(column_name, data, 5000, exclude_classes, 324)# look at Resource Utilization to see if capping

In [None]:
# Choose the loaders the experiments below will use. BALANCED is recorded in
# ft.conf by run_lora and should describe the loaders selected here.
BALANCED = False
dataloaders = dataloaders_imbalanced

## LoRA experiments on garment_group_name

In [None]:
# File name forwarded to ft.initialize(...) inside run_lora. That call uses load=False,
# so the file is presumably not read — TODO confirm in training.FinetuneCLIP.
file_name = "lora-cap-5000-2-120-start_lora.pth"

LoRA, unweighted, no hard mining

In [None]:

def run_lora(weighted, hard_mining, fc=False):
    """Fine-tune a copy of CLIP with LoRA for one configuration, then evaluate it.

    Relies on notebook globals defined in earlier cells: ``model``, ``processor``,
    ``dataloaders``, ``file_name`` and ``BALANCED``.

    Args:
        weighted: Stored on the trainer as ``ft.weighted`` and included in the
            model prefix (presumably toggles a class-weighted loss — confirm in
            ``training``).
        hard_mining: Stored on the trainer as ``ft.hard_mining`` and included in
            the model prefix (presumably toggles hard-example mining — confirm in
            ``training``).
        fc: When truthy, ``ft.tt['image_fc']`` is set to 1 in addition to LoRA.

    Returns:
        The accuracy value returned by ``ft.eval(False)`` after training.
    """
    update()

    # One LoRA rank per text-encoder layer: rank 256 on the last two layers, 0 on the
    # ten before them (0 presumably means "no adapter" — confirm in model_functions).
    # Alternative that was tried: rank 256 on every layer, i.e. [256] * 12.
    ranks = [0] * 10 + [256, 256]

    lr = 1e-04
    wd = 0.001
    epochs_num = 100
    lora_layers = []

    # Deep-copy the already loaded model so each run starts from the pretrained
    # weights without reloading the checkpoint or modifying the global `model`.
    clip = {'m': copy.deepcopy(model), 'p': processor}
    lora_layers = model_functions.apply_lora_to_transformer(clip['m'].text_model.encoder.layers, lora_layers, ranks)
    lora_params_attention = model_functions.get_lora_params(clip['m'], print_layer=True)

    ft = training.FinetuneCLIP(dataloaders, clip, epochs=epochs_num)
    ft.conf = {'epochs': epochs_num, 'balanced': BALANCED}
    ft.model_prefix = f"draft-experiments/weighted={weighted}_hard-mining={hard_mining}_fc={fc}"
    ft.hard_mining = hard_mining
    ft.weighted = weighted

    # Training-type switches: LoRA on, soft prompts off, image FC head only on request.
    ft.tt['soft'], ft.tt['LoRA'], ft.tt['image_fc'] = 0, 1, 0
    if fc:
        ft.tt['image_fc'] = 1

    # Initialize LoRA training with the current hyperparameters.
    ft.initialize({'LoRA': lora_params_attention, 'lr': lr, 'weight_decay': wd, 'num_soft': 0, 'add': ''},
                  load=False, file_name=file_name)

    ft.count_parameters()

    # Train the model ('pat' is presumably the early-stopping patience — confirm in training).
    ft.es['pat'] = 30
    ft.train()

    # Evaluate the model.
    all_predictions, all_labels, acc = ft.eval(False)
    utils.confussion_matrix(all_labels, all_predictions, list(dataloaders['test'].dataset.class_to_id.keys()), F1=False)
    ft.plot_loss_key('train', 'final')
    ft.plot_loss_key('val', 'final')

    # Report the accuracy together with the configuration that produced it
    # (the previous message printed neither), and return it to the caller.
    print(f"Accuracy for weighted={weighted}, hard_mining={hard_mining}, fc={fc}: {acc}")
    return acc

In [None]:
# Run every (weighted, hard_mining) combination: (T, T), (T, F), (F, T), (F, F).
for weighted, hard_mining in itertools.product((True, False), repeat=2):
    run_lora(weighted, hard_mining)