In [2]:
!pip install matplotlib
!pip install pandas
!pip install scikit-learn
!pip install pyarrow
!pip install wandb

import ssl
import pandas as pd
import wandb

# SECURITY NOTE(review): this swaps the default HTTPS context factory for the unverified
# one, so HTTPS requests made through the standard library afterwards skip certificate
# verification for the whole process. Presumably a workaround for a certificate problem
# in the notebook environment — confirm it is still required before keeping it.
ssl._create_default_https_context = ssl._create_unverified_context
import os
from torch.utils.data import DataLoader
from torch import nn
from torchvision.io import read_image
from torchvision.transforms import v2
import torch
import torchvision.transforms as T
import json
import urllib
import requests
from PIL import Image
from io import BytesIO
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm




Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md

Matplotlib is building the font cache; this may take a moment.


In [2]:
# Choose the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

# Configuration

In [3]:
# Hyper-parameters and run settings read throughout the notebook.
model_config = dict(
    init_lr=0.01,
    batch_size=32,
    epochs=25,
    empty_image_representation='zero_matrix',  # zero_matrix, torch_empty
    dino_architecture='small',
    dataset='polyvore_63eb50dc58d97415384467bef7b3c9e1bd6c96e06ad19571b6bc15e9dd5af262.parquet',
    model_forward_version='fast',  # slow, fast
    hidden_layer_neuron_count=64,
    dropout_probability=0.1,
    regularisation='l1',
    regularisation_weight=1,
)

# Root folder of all datasets and the parquet file selected by the configuration.
dataset_folder_root_path = '../datasets'
dataset_path = '/'.join(
    [dataset_folder_root_path, 'imageBasedModel', 'polyvore', model_config['dataset']]
)

# Data Augmentation

In [3]:
# Per-channel statistics used by the final Normalize step of every split.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _resize_and_crop_steps():
    """Return fresh transforms that resize to 224 (bicubic) and centre-crop to 224x224."""
    return [
        v2.PILToTensor(),
        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
        v2.CenterCrop(224),
    ]


def _augmentation_steps():
    """Return fresh random augmentations; only the training split uses them."""
    return [
        v2.RandomHorizontalFlip(),
        v2.RandomPerspective(fill=255),
        v2.RandomAffine(30, fill=255),
        # v2.AutoAugment(v2.AutoAugmentPolicy.IMAGENET),
    ]


def _to_normalised_float_steps():
    """Return fresh transforms that convert to float32 and normalise per channel."""
    return [
        v2.ConvertImageDtype(torch.float32),
        v2.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]


# One pipeline per split; 'val' and 'test' share the same deterministic steps.
data_transforms = {
    'train': v2.Compose(
        _resize_and_crop_steps() + _augmentation_steps() + _to_normalised_float_steps()
    ),
    'val': v2.Compose(_resize_and_crop_steps() + _to_normalised_float_steps()),
    'test': v2.Compose(_resize_and_crop_steps() + _to_normalised_float_steps()),
}


data_transforms


{'train': Compose(
       PILToTensor()
       Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
       CenterCrop(size=(224, 224))
       RandomHorizontalFlip(p=0.5)
       RandomPerspective(p=0.5, distortion_scale=0.5, interpolation=InterpolationMode.BILINEAR, fill=255)
       RandomAffine(degrees=[-30.0, 30.0], interpolation=InterpolationMode.NEAREST, fill=255)
       ConvertImageDtype()
       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
 ),
 'val': Compose(
       PILToTensor()
       Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
       CenterCrop(size=(224, 224))
       ConvertImageDtype()
       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
 ),
 'test': Compose(
       PILToTensor()
       Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
       CenterCrop(size=(224, 224))
       ConvertImageDtype()
       Normalize(mean=[0.485, 0.456, 0.406], s

# Set seeds

In [4]:
def fix_random_seeds(seed=12345):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and PyTorch
    (CPU generator plus all CUDA devices), so repeated runs draw the same numbers.

    Args:
        seed: Integer seed handed to every generator.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored by PyTorch when no CUDA device is available.
    torch.cuda.manual_seed_all(seed)


fix_random_seeds()

# Load feature extractor model

In [5]:
# Map the configured size name to the matching DINOv2 ViT hub identifier suffix.
backbone_archs = dict(small='vits14', base='vitb14', large='vitl14', giant='vitg14')

backbone_arch = backbone_archs[model_config['dino_architecture']]
backbone_name = 'dinov2_' + backbone_arch

# Fetch the backbone through torch.hub, then switch it to eval mode on the selected device.
feature_extraction_model = torch.hub.load('facebookresearch/dinov2', backbone_name)
feature_extraction_model = feature_extraction_model.eval().to(device)

Using cache found in /home/jovyan/.cache/torch/hub/facebookresearch_dinov2_main


In [6]:
# Show the backbone's embedding width (features per image).
feature_extraction_model.embed_dim

384

# Freeze the weights of the feature extractor

In [7]:
# Exclude every backbone parameter from gradient computation.
for backbone_parameter in feature_extraction_model.parameters():
    backbone_parameter.requires_grad_(False)

# Testing model

In [8]:
# Get ImageNet labels
# `import urllib` alone does not guarantee that the `request` submodule is loaded.
import urllib.request

imagenet_class_url = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
# Close the HTTP response deterministically once the label list has been read.
with urllib.request.urlopen(imagenet_class_url) as labels_response:
    imagenet_classes = json.loads(labels_response.read())


# # Set a device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#https://www.modelbit.com/blog/deploying-dinov2-for-image-classification-with-modelbit#:~:text=To%20do%20this%2C%20simply%20use,machine%20to%20your%20Colab%20directory.&text=Next%2C%20you'll%20want%20to,ImageNet%20preprocessing%20on%20the%20image.&text=Now%2C%20we%20can%20pass%20the,a%20class%20ID%20and%20label.
def dinov2_classifier(img_url):
    """Download an image, run it through the backbone and report the arg-max of its output.

    NOTE(review): the model loaded in this notebook is the plain ``dinov2_*`` backbone,
    whose output has ``embed_dim`` entries per image. The arg-max is therefore taken
    over embedding dimensions, not over ImageNet class logits — confirm whether a
    checkpoint with a classification head was intended before trusting the label.

    Args:
        img_url: HTTP(S) URL of the image to process.

    Returns:
        dict with 'index' (arg-max position of the model output) and 'label'
        (the entry of ``imagenet_classes`` at that position).

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within the timeout.
    """
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response = requests.get(img_url, timeout=30)
    response.raise_for_status()
    # Force three channels: RGBA / palette / grayscale inputs would break the 3-channel Normalize.
    image = Image.open(BytesIO(response.content)).convert('RGB')

    # Preprocess the image
    transform = T.Compose([
        T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    image = transform(image)

    # Move the image to the GPU if available
    image = image.to(device)

    # Extract the features (single-image batch, squeezed back to one vector)
    with torch.no_grad():
        features = torch.squeeze(feature_extraction_model(image.unsqueeze(0)))

    predicted_index = features.argmax(-1).item()
    return {'index': predicted_index,
            'label': imagenet_classes[predicted_index]
            }

In [9]:
import time

# Time one end-to-end call (download, preprocessing, forward pass) on a sample image.
sample_image_url = "https://www.apple.com/v/iphone/home/bu/images/meta/iphone__ky2k6x5u6vue_og.png"
start_time = time.time()
dinov2_classifier(sample_image_url)
elapsed_seconds = time.time() - start_time
print("--- %s seconds ---" % elapsed_seconds)

--- 2.288867712020874 seconds ---


# Implementing Custom Model

In [10]:
def get_image(img_path):
    """Load one clothing-item image, or a placeholder when the outfit has no such item.

    Args:
        img_path: Image path relative to ``dataset_folder_root_path``; a ``raw/images``
            segment is redirected to ``resized/256x256``. A missing value
            (``None`` or NaN) marks an empty slot.

    Returns:
        torch.Tensor: the decoded image with three channels, or a ``(3, 224, 224)``
        placeholder chosen by ``model_config['empty_image_representation']``.

    Raises:
        ValueError: if ``model_config['empty_image_representation']`` is not a
            supported value.
    """
    # Local import: only `read_image` is imported at the top of the notebook.
    from torchvision.io import ImageReadMode

    # Missing parquet values can surface as None or as NaN depending on the column dtype.
    if img_path is None or pd.isna(img_path):
        empty_representation = model_config["empty_image_representation"]
        if empty_representation == "zero_matrix":
            return torch.zeros(3, 224, 224)
        if empty_representation == "torch_empty":
            # torch.empty returns uninitialised memory, so the placeholder content is arbitrary.
            return torch.empty(3, 224, 224)
        raise ValueError("Wrong configuration value for key empty_image_representation in model_configuration")

    img_path = img_path.replace('raw/images', 'resized/256x256')
    # Decode as RGB so grayscale or RGBA files still yield the three channels the model expects.
    return read_image(f'{dataset_folder_root_path}/{img_path}', mode=ImageReadMode.RGB)


class OutfitClassifier(nn.Module):
    """Binary outfit classifier: a small MLP head on top of the frozen feature extractor.

    Each outfit consists of ``ITEMS_PER_OUTFIT`` images. Every image is embedded by the
    module-level ``feature_extraction_model`` (without gradient tracking), the
    embeddings of one outfit are concatenated, and the MLP maps them to a single
    sigmoid probability.
    """

    # Number of clothing-item slots per outfit.
    ITEMS_PER_OUTFIT = 5

    @property
    def embed_dim(self):
        """Length of the concatenated per-outfit feature vector fed to the MLP."""
        return self._embed_dim

    def __init__(self):
        super().__init__()
        self._embed_dim = feature_extraction_model.embed_dim * self.ITEMS_PER_OUTFIT

        # The backbone is shared, stays in eval mode and is not registered as a sub-module,
        # so it is neither trained nor stored in this module's state_dict.
        feature_extraction_model.eval().to(device)

        hidden_size = model_config['hidden_layer_neuron_count']
        dropout_probability = model_config['dropout_probability']

        self.trainable_model = nn.Sequential(
            nn.Linear(self._embed_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_probability),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_probability),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_probability),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

        self.trainable_model.to(device)
        # Number of forward passes since the counter was last reset.
        self.counter = 0

    def forward(self, X):
        """Return one probability per outfit.

        Args:
            X: Tensor holding ``batch_size * ITEMS_PER_OUTFIT`` images of shape
                ``(3, 224, 224)``, e.g. ``(batch_size, ITEMS_PER_OUTFIT, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch_size, 1)`` with values in ``[0, 1]``.
        """
        number_of_rows = X.shape[0]

        with torch.no_grad():
            # Fold the item axis into the batch axis so the backbone sees plain images.
            # `reshape` (unlike `view`) also accepts non-contiguous input.
            dino_input = X.reshape(number_of_rows * self.ITEMS_PER_OUTFIT, 3, 224, 224)
            item_features = feature_extraction_model(dino_input)

        # Concatenate the item embeddings of each outfit: [batch_size, embed_dim].
        # The features are plain inputs to the head, so they need no gradient themselves.
        batch_features = item_features.reshape(number_of_rows, self.embed_dim)

        self.counter += 1

        # Call the module rather than `.forward` so registered hooks are honoured.
        return self.trainable_model(batch_features)

    def compute_l1_loss(self, w):
        """Return the L1 norm (sum of absolute values) of tensor ``w``."""
        return torch.abs(w).sum()


In [11]:
class CustomDataset(torch.utils.data.Dataset):
    """Outfit dataset that loads all item images up front and transforms them on access.

    Each sample is a pair ``(feature_tensor, target_variable)``: the five transformed
    item images stacked along a new first axis, and the ``valid_outfit`` label as a
    one-element float tensor. Both are placed on ``device``.
    """

    # Image columns, in the order their tensors are stacked by __getitem__.
    _IMAGE_COLUMNS = (
        'Accessoire_imagePath',
        'Innerwear_imagePath',
        'Bottomwear_imagePath',
        'Shoes_imagePath',
        'Outerwear_imagePath',
    )

    def __init__(self, data_frame, tfms: v2.Compose, name):
        """Load every image referenced by ``data_frame``.

        Args:
            data_frame: Table with the image-path columns and a ``valid_outfit`` column.
            tfms: Transform pipeline applied to each image in ``__getitem__``.
            name: Split name shown in the progress bar description.
        """
        self.df = data_frame
        self.tfms = tfms

        loaded_images = {column: [] for column in self._IMAGE_COLUMNS}
        labels = []
        progress = tqdm(self.df.iterrows(), total=self.df.shape[0], desc=f'Loading {name} dataset')
        for _, outfit in progress:
            for column in self._IMAGE_COLUMNS:
                loaded_images[column].append(get_image(outfit[column]))
            labels.append(outfit['valid_outfit'])

        # Same column names as the source frame, but holding tensors instead of paths.
        self.feature_df = pd.DataFrame({**loaded_images, 'valid_outfit': labels}, index=self.df.index)

    def __getitem__(self, index):
        """Return the transformed item images and the label of the outfit at ``index``."""
        outfit = self.feature_df.iloc[index]
        item_images = [self.tfms(outfit[column].to(device)) for column in self._IMAGE_COLUMNS]

        target_variable = torch.tensor([outfit['valid_outfit']]).to(torch.float).to(device)

        # Stack the items along a new leading axis.
        feature_tensor = torch.stack(item_images).to(device)

        return feature_tensor, target_variable

    def __len__(self):
        """Return the number of outfits."""
        return len(self.df)

In [12]:
# Read the outfit table selected in the configuration.
df = pd.read_parquet(dataset_path)
#df = df.iloc[:100].copy()

df

Unnamed: 0,Innerwear_imagePath,Bottomwear_imagePath,Accessoire_imagePath,Shoes_imagePath,Outerwear_imagePath,valid_outfit
120161271,raw/images/120161271/1.jpg,raw/images/120161271/2.jpg,,,,0
143656996,raw/images/143656996/1.jpg,raw/images/143656996/3.jpg,raw/images/143656996/5.jpg,,,0
216470135,raw/images/216470135/1.jpg,raw/images/216470135/2.jpg,,raw/images/216470135/3.jpg,,1
216220312,raw/images/216220312/1.jpg,raw/images/216220312/2.jpg,raw/images/216220312/4.jpg,raw/images/216220312/3.jpg,,1
192203629,raw/images/192203629/2.jpg,raw/images/192203629/3.jpg,raw/images/192203629/5.jpg,,raw/images/192203629/1.jpg,0
...,...,...,...,...,...,...
201717504,raw/images/201717504/1.jpg,raw/images/201717504/2.jpg,raw/images/201717504/4.jpg,raw/images/201717504/3.jpg,,1
216589548,raw/images/216589548/1.jpg,raw/images/216589548/3.jpg,raw/images/216589548/5.jpg,,raw/images/216589548/2.jpg,1
216860218,raw/images/216860218/1.jpg,raw/images/216860218/3.jpg,raw/images/216860218/5.jpg,,raw/images/216860218/2.jpg,1
211099953,raw/images/211099953/1.jpg,raw/images/211099953/2.jpg,,raw/images/211099953/3.jpg,,1


In [13]:
from sklearn.model_selection import train_test_split

# Stratified split on the label: 20% held out for testing, then a quarter of the
# remainder for validation (60/20/20 overall).
SPLIT_SEED = 42
train_and_validation, test = train_test_split(
    df, test_size=0.20, random_state=SPLIT_SEED, stratify=df['valid_outfit']
)
train, validation = train_test_split(
    train_and_validation,
    test_size=0.25,
    random_state=SPLIT_SEED,
    stratify=train_and_validation['valid_outfit'],
)

train

Unnamed: 0,Innerwear_imagePath,Bottomwear_imagePath,Accessoire_imagePath,Shoes_imagePath,Outerwear_imagePath,valid_outfit
140890817,raw/images/140890817/1.jpg,raw/images/140890817/2.jpg,,,,0
203955931,raw/images/203955931/1.jpg,raw/images/203955931/2.jpg,raw/images/203955931/3.jpg,raw/images/203955931/7.jpg,,0
216947310,raw/images/216947310/1.jpg,raw/images/216947310/3.jpg,raw/images/216947310/5.jpg,raw/images/216947310/4.jpg,,0
210526001,raw/images/210526001/1.jpg,raw/images/210526001/2.jpg,raw/images/210526001/4.jpg,raw/images/210526001/3.jpg,,0
215262893,raw/images/215262893/1.jpg,raw/images/215262893/2.jpg,raw/images/215262893/4.jpg,,,1
...,...,...,...,...,...,...
146338069,raw/images/146338069/1.jpg,raw/images/146338069/2.jpg,raw/images/146338069/4.jpg,raw/images/146338069/3.jpg,,0
216252775,raw/images/216252775/1.jpg,raw/images/216252775/3.jpg,raw/images/216252775/5.jpg,,raw/images/216252775/2.jpg,1
213715463,raw/images/213715463/1.jpg,raw/images/213715463/2.jpg,,,,1
170950241,raw/images/170950241/1.jpg,raw/images/170950241/3.jpg,raw/images/170950241/5.jpg,raw/images/170950241/4.jpg,raw/images/170950241/2.jpg,1


In [14]:
CROP_SIZE = 256  # NOTE(review): not referenced in the visible cells — confirm whether it is still needed
BATCH_SIZE = model_config['batch_size']
# The datasets return tensors that already live on `device`, so loading stays in the main process.
NUM_WORKERS = 0

# One dataset per split, each with its own transform pipeline.
image_datasets = {
    'train': CustomDataset(train, data_transforms['train'], 'training'),
    'val': CustomDataset(validation, data_transforms['val'], 'validation'),
    'test': CustomDataset(test, data_transforms['test'], 'test')
}

# Only the training split is shuffled; evaluation splits keep a fixed order.
# A loader for the held-out 'test' split is built as well so it can be evaluated.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == 'train'),
        num_workers=NUM_WORKERS,
    )
    for split in ['train', 'val', 'test']
}


def get_image_for_matplot_lib(img):
    """Undo the per-channel normalisation of an image tensor for plotting.

    Args:
        img: Tensor with the channel axis first, i.e. ``(3, H, W)``.

    Returns:
        numpy.ndarray of shape ``(H, W, 3)`` with values clipped to ``[0, 1]``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    # Move the channel axis last, which is the layout imshow expects.
    channels_last = np.transpose(img.cpu().numpy(), (1, 2, 0))
    return np.clip(channel_std * channels_last + channel_mean, 0, 1)


def show_batch(imgs, classification):
    """Plot every outfit of a batch as one row of item images followed by its label.

    Args:
        imgs: Batch as yielded by the data loaders, indexed ``imgs[outfit][item]``, where
            each entry is a normalised ``(3, H, W)`` image and the items are ordered
            accessoire, innerwear, bottomwear, shoes, outerwear.
        classification: Per-outfit labels; a value of 1 is rendered as 'good',
            anything else as 'bad'.
    """
    # Positions within one outfit, listed in on-screen order:
    # accessoire, outerwear, innerwear, bottomwear, shoes.
    display_order = [0, 4, 1, 2, 3]
    label_font_size = 20

    for outfit_items, label in zip(imgs, classification):
        # One axis per clothing item plus a final axis for the text label.
        f, axarr = plt.subplots(1, len(display_order) + 1, figsize=(15, 2))
        f.patch.set_facecolor('black')

        for ax, item_index in zip(axarr, display_order):
            ax.imshow(get_image_for_matplot_lib(outfit_items[item_index]))
            ax.axis('off')

        ax = axarr[-1]
        ax.set_xlim([0, 0.5])
        ax.set_ylim([0, 0.5])

        is_a_good_outfit = label == 1
        if is_a_good_outfit:
            label_text, label_color = 'good', 'green'
        else:
            label_text, label_color = 'bad', 'red'
        ax.text(0.5, 0.5, label_text, horizontalalignment='center', transform=ax.transAxes, weight='bold',
                color=label_color, fontsize=label_font_size)

        ax.axis('off')

        f.tight_layout()


# Pull a single batch from the training loader (the show_batch preview cell takes these as arguments).
inputs, classification = next(iter(dataloaders['train']))


Loading training dataset: 100%|██████████| 4033/4033 [05:44<00:00, 11.69it/s]
Loading validation dataset: 100%|██████████| 1345/1345 [01:46<00:00, 12.64it/s]
Loading test dataset: 100%|██████████| 1345/1345 [01:41<00:00, 13.26it/s]


In [15]:
#show_batch(inputs, classification)

In [16]:
def _regularisation_penalty(feature_model):
    """Return the weight penalty selected by ``model_config['regularisation']``.

    'l1' gives ``weight * sum(|w|)`` and 'l2' gives ``weight * sum(w ** 2)``, both over
    all parameters of ``feature_model.trainable_model``. Any other value raises.
    """
    kind = model_config['regularisation']
    if kind not in ('l1', 'l2'):
        raise Exception(
            f'configuration value for key regularisation was {model_config["regularisation"]} which is not a valid configuraiton value')

    flat_parameters = torch.cat(
        [parameter.view(-1) for parameter in feature_model.trainable_model.parameters()])
    weight = model_config['regularisation_weight']
    if kind == 'l1':
        # lasso regularization
        return weight * feature_model.compute_l1_loss(flat_parameters)
    # ridge regularization
    return weight * flat_parameters.pow(2).sum()


def train_loop(dataloader, feature_model, loss_fn, optimizer):
    """Train ``feature_model`` for one epoch.

    Args:
        dataloader: Yields ``(X, y)`` batches; its ``dataset`` must support ``len``.
        feature_model: Model to train; must expose ``trainable_model``,
            ``compute_l1_loss`` and a ``counter`` attribute.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.
        optimizer: Optimizer that updates the trainable parameters.

    Returns:
        tuple: ``(epoch_acc, epoch_loss)``. Accuracy is the fraction of samples whose
        rounded prediction equals the target. The loss is the mean per-batch objective,
        which includes the regularisation penalty.

    Raises:
        Exception: if ``model_config['regularisation']`` is neither 'l1' nor 'l2'.
    """
    # Training mode matters here: the classifier head contains Dropout layers.
    feature_model.train()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    running_loss = 0.0
    running_corrects = 0
    feature_model.counter = 0

    for X, y in tqdm(dataloader, desc='Training', total=num_batches):
        # Compute prediction and the regularised objective
        pred = feature_model(X)
        loss = loss_fn(pred, y) + _regularisation_penalty(feature_model)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        running_corrects += (pred.round() == y).type(torch.float).sum().item()

    epoch_loss = running_loss / num_batches
    epoch_acc = running_corrects / size

    return epoch_acc, epoch_loss

In [17]:
@torch.inference_mode()
def val_loop(dataloader, feature_model, loss_fn):
    """Evaluate ``feature_model`` on every batch of ``dataloader`` without gradients.

    Args:
        dataloader: Yields ``(X, y)`` batches; its ``dataset`` must support ``len``.
        feature_model: Model to evaluate; it is switched to eval mode.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.

    Returns:
        tuple: ``(val_acc, val_loss)``. Accuracy is the fraction of samples whose
        rounded prediction equals the target; loss is the mean of the per-batch losses.
    """
    # Evaluation mode disables the Dropout layers of the classifier head.
    feature_model.eval()
    sample_count = len(dataloader.dataset)
    batch_count = len(dataloader)
    loss_sum = 0.0
    correct_count = 0

    # The inference_mode decorator already disables gradient tracking for this loop.
    for X, y in tqdm(dataloader, desc='Validation', total=batch_count):
        pred = feature_model(X)
        loss_sum += loss_fn(pred, y).item()
        correct_count += (pred.round() == y).type(torch.float).sum().item()

    return correct_count / sample_count, loss_sum / batch_count

In [18]:
# Build the classifier, place it on the selected device and display its layer layout.
outfit_classifier = OutfitClassifier()
outfit_classifier.to(device)

outfit_classifier

OutfitClassifier(
  (trainable_model): Sequential(
    (0): Linear(in_features=1920, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=64, out_features=64, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=64, out_features=1, bias=True)
    (10): Sigmoid()
  )
)

In [19]:
# Fresh classifier instance for the training run; the optimizer below is bound to its parameters.
outfit_classifier = OutfitClassifier()

# NOTE(review): train_model() calls wandb.init again with a different project name —
# confirm whether this additional run is intended.
wandb.init(
    # set the wandb project where this run will be logged
    project="ReWear - Outfit Recommender (DSPRO2)",
    # track hyperparameters and run metadata
    config=model_config
)

# Binary cross-entropy on probabilities; the classifier head ends in a Sigmoid.
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(
    outfit_classifier.parameters(),
    lr=model_config['init_lr']
)
# optimizer = torch.optim.SGD(
#     linear_classifier.parameters(),
#     lr=model_config['init_lr'],
#     momentum=0.9,
#     weight_decay=0, # we do not apply weight decay
# )


# Cosine learning-rate schedule with T_max set to the configured number of epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, model_config['epochs'], eta_min=0)

# Count parameters with requires_grad=True: once over the whole module, once over the MLP head only.
pytorch_total_params = sum(p.numel() for p in outfit_classifier.parameters() if p.requires_grad)
pytorch_total_params_trainable = sum(
    p.numel() for p in outfit_classifier.trainable_model.parameters() if p.requires_grad)

print(f'total trainable params: {pytorch_total_params}')
print(f'total trainable params trainable model: {pytorch_total_params_trainable}')

[34m[1mwandb[0m: Currently logged in as: [33mdata_scientist_24[0m ([33mrz_datascience[0m). Use [1m`wandb login --relogin`[0m to force relogin


total trainable params: 131329
total trainable params trainable model: 131329


In [None]:
import numpy as np


def train_model(config, data_dir=None):
    """Train the notebook's classifier and track metrics and the best checkpoint in wandb.

    Works on the module-level ``outfit_classifier``, ``dataloaders``, ``loss_fn``,
    ``optimizer`` and ``scheduler``.

    Args:
        config: Dict of hyper-parameters recorded as the wandb run config. Its
            ``'epochs'`` entry, when present, sets the number of epochs; otherwise
            ``model_config['epochs']`` is used.
        data_dir: Unused; kept so existing callers keep working.
    """
    wandb.init(
        # set the wandb project where this run will be logged
        project="ReWear - Outfit Recommender (DSPRO2) Dino v2 based",
        # track hyperparameters and run metadata
        config=config
    )

    checkpoint_path = os.path.join(wandb.run.dir, 'dino_classifier_ckpt.pth')
    epochs = (config or {}).get('epochs', model_config['epochs'])
    best_acc = 0.0
    best_acc_loss = np.inf

    try:
        for t in range(epochs):
            print(f'Epoch {t + 1}\n-------------------------------')

            # Learning rate in effect during this epoch; read before the scheduler advances it.
            epoch_lr = optimizer.param_groups[0]["lr"]

            train_acc, train_loss = train_loop(dataloaders['train'], outfit_classifier, loss_fn, optimizer)
            scheduler.step()
            val_acc, val_loss = val_loop(dataloaders['val'], outfit_classifier, loss_fn)

            wandb.log(
                {
                    'epoch': t,
                    'lr': epoch_lr,
                    'training_accuracy': train_acc,
                    'training_loss': train_loss,
                    'validation_accuracy': val_acc,
                    'validation_loss': val_loss
                }
            )

            # Keep the checkpoint with the best validation accuracy; ties go to the lower loss.
            if (val_acc == best_acc and val_loss < best_acc_loss) or (val_acc > best_acc):
                best_acc, best_acc_loss = val_acc, val_loss
                save_dict = {
                    'epoch': t + 1,
                    'state_dict': outfit_classifier.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'best_acc': best_acc,
                    'best_loss': best_acc_loss
                }
                torch.save(save_dict, checkpoint_path)

            print('\n')
        print('Training completed.')
        # The checkpoint is written inside the run directory, so refer to it by its full path.
        if os.path.exists(checkpoint_path):
            wandb.save(checkpoint_path)
    finally:
        # Close the wandb run even when training is interrupted or fails.
        wandb.finish()

Epoch 1
-------------------------------


Training: 100%|██████████| 127/127 [02:54<00:00,  1.38s/it]
Validation: 100%|██████████| 43/43 [00:52<00:00,  1.21s/it]


Best accuracy = 0.5479553903345725, best_loss = 0.690824102523715


Epoch 2
-------------------------------


Training: 100%|██████████| 127/127 [02:54<00:00,  1.37s/it]
Validation:   9%|▉         | 4/43 [00:04<00:48,  1.24s/it]

In [None]:
from ray import tune

# Configuration handed to the trainable; the commented entries sketch a possible search space.
config = {
    # "l1": tune.choice([2 ** i for i in range(9)]),
    # "l2": tune.choice([2 ** i for i in range(9)]),
    # "lr": tune.loguniform(1e-4, 1e-1),
    # "batch_size": tune.choice([2, 4, 8, 16])
    'init_lr': 0.01,
    'batch_size': 32,
    'epochs': 25,
    'empty_image_representation': 'zero_matrix',  #  zero_matrix, torch_empty
    'dino_architecture': 'small',
    'training_dataset': 'polyvore_63eb50dc58d97415384467bef7b3c9e1bd6c96e06ad19571b6bc15e9dd5af262.parquet',
    'testing_dataset': '',
    'model_forward_version': 'fast',  # slow, fast
    'hidden_layer_neuron_count': 64,
    'dropout_probability': 0.1,
}

# tune.run needs the trainable as its first argument; calling it with no arguments raises a TypeError.
# NOTE(review): train_model relies on notebook globals (model, data loaders, optimizer) —
# confirm they can be shipped to Ray workers, or build them inside the trainable.
tune.run(train_model, config=config)