# Training and testing a binary image classifier for automatic labeling of clothing images from online shops.

# Class 1 represents an unaltered full view of a clothing item laid out on a flat surface with a plain background.

# Class 0 represents everything else (e.g. photos of models wearing the clothing items).

In [None]:
from google.colab import drive
# Mount Google Drive inside the Colab VM.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!pip install boto3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting boto3
  Downloading boto3-1.26.160-py3-none-any.whl (135 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.9/135.9 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting botocore<1.30.0,>=1.29.160 (from boto3)
  Downloading botocore-1.29.160-py3-none-any.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m79.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jmespath<2.0.0,>=0.7.1 (from boto3)
  Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)
Collecting s3transfer<0.7.0,>=0.6.0 (from boto3)
  Downloading s3transfer-0.6.1-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jmespath, botocore, s3transfer, boto3
Successfully installed boto3-1.26.160 botocore-1.29.160 jmespath-1.0.

In [None]:
import shutil
import json
import gc

from tqdm.auto import tqdm

import boto3
from botocore.client import Config
from pathlib import Path

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import cv2


# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def setup_s3_connection(secrets_path='secrets.json', service_name='s3', endpoint_url='https://s3.timeweb.com', region_name='ru-1'):
    '''
    Create a Boto3 client object for S3 access.

    Requires a JSON file with S3 credentials in the following format:
    {
        "aws_access_key_id": "qwerty",
        "aws_secret_access_key": "a1b2c3d4etc"
    }

    -----
    Args:
        secrets_path (str or Path): path to the JSON file with S3 credentials, defaults to a file named 'secrets.json' in the current folder
        service_name (str): name of the service to use with the client, defaults to 's3'
        endpoint_url (str): the complete URL to use for the constructed client, defaults to 'https://s3.timeweb.com'
        region_name (str): the name of the region associated with the client, defaults to 'ru-1'

    Returns:
        s3: the boto3 service client instance

    Raises:
        FileNotFoundError: if `secrets_path` does not exist.
        KeyError: if either credential key is missing from the JSON file.
    '''
    with open(secrets_path, 'r', encoding='utf-8') as secrets_file:
        secrets = json.load(secrets_file)

    # Fail with a message naming the file and the missing key(s) instead of a bare KeyError.
    required_keys = ('aws_access_key_id', 'aws_secret_access_key')
    missing = [key for key in required_keys if key not in secrets]
    if missing:
        raise KeyError(f"{secrets_path} is missing credential key(s): {', '.join(missing)}")

    s3 = boto3.client(
        service_name=service_name,
        endpoint_url=endpoint_url,
        region_name=region_name,
        aws_access_key_id=secrets['aws_access_key_id'],
        aws_secret_access_key=secrets['aws_secret_access_key'],
        # path-style addressing: bucket goes in the URL path, not the hostname
        config=Config(s3={'addressing_style': 'path'})
    )
    return s3

In [None]:
# Credentials come from ./secrets.json, which is the function's default location.
s3 = setup_s3_connection(secrets_path='secrets.json')

In [None]:
S3_BUCKET = "s3-bucket"

# Archive name -> local folder it is unpacked into. H&M is held out as the test shop.
IMAGE_ARCHIVES = {
    "bershka.zip": "images",
    "finn_flare.zip": "images",
    "guess.zip": "images",
    "marks_&_spencer.zip": "images",
    "roxy.zip": "images",
    "zara.zip": "images",
    "h&m.zip": "test_images",
}


def download_once(key, local_path):
    '''Download s3://S3_BUCKET/<key> to `local_path` unless that file already exists.

    Delete the local file to force a fresh download.
    '''
    if not Path(local_path).exists():
        s3.download_file(S3_BUCKET, key, local_path)


download_once("data/short_table.csv", "products.csv")

for archive, destination in IMAGE_ARCHIVES.items():
    download_once(f"data/image_archives/{archive}", archive)
    shutil.unpack_archive(archive, destination)

# Train

In [None]:
# Images are unpacked as images/<hh>/<hh>/<hash>.jpg; keep only the bare hashes (file stems)
# so they can be matched against the "image_hash" column.
valid_files = [
    leaf.stem
    for level_1 in Path("images").iterdir()
    for level_2 in level_1.iterdir()
    for leaf in level_2.iterdir()
]

In [None]:
n_valid_files = len(valid_files)
n_valid_files

123447

In [None]:
def _hash_to_relpath(image_hash):
    '''Relative path of an image inside the two-level hash tree: "ab/cd/abcd....jpg".'''
    return f"{image_hash[:2]}/{image_hash[2:4]}/{image_hash}.jpg"


df = pd.read_csv("products.csv")
df["path"] = df["image_hash"].map(_hash_to_relpath)
df

Unnamed: 0,product_link,product_name,image_original,name_of_shop,image_hash,path
0,https://cos.juun.ru/men/64_одежда/193_блейзеры...,БЛЕЙЗЕР ОБЫЧНОГО ПОКРОЯ,https://c.juun.ru/image/products/cos/9c/f6/9cf...,juun,17db7609ee3ed408fbf8f9eefcd902dd,17/db/17db7609ee3ed408fbf8f9eefcd902dd.jpg
1,https://cos.juun.ru/men/64_одежда/193_блейзеры...,БЛЕЙЗЕР ОБЫЧНОГО ПОКРОЯ,https://c.juun.ru/image/products/cos/b3/c7/b3...,juun,b1dffc1276b8e05e683fbe689baf18b6,b1/df/b1dffc1276b8e05e683fbe689baf18b6.jpg
2,https://cos.juun.ru/men/64_одежда/193_блейзеры...,БЛЕЙЗЕР ОБЫЧНОГО ПОКРОЯ,https://c.juun.ru/image/products/cos/74/cd/74...,juun,64f63cd1a4e0ae75981bb4a5f22db340,64/f6/64f63cd1a4e0ae75981bb4a5f22db340.jpg
3,https://cos.juun.ru/men/64_одежда/193_блейзеры...,БЛЕЙЗЕР ОБЫЧНОГО ПОКРОЯ,https://c.juun.ru/image/products/cos/76/16/76...,juun,26fbdd79bc1a7b64540aaea8adb37101,26/fb/26fbdd79bc1a7b64540aaea8adb37101.jpg
4,https://cos.juun.ru/men/64_одежда/193_блейзеры...,БЛЕЙЗЕР ОБЫЧНОГО ПОКРОЯ,https://c.juun.ru/image/products/cos/b9/87/b9...,juun,096691458b009d001c7846185c98de2f,09/66/096691458b009d001c7846185c98de2f.jpg
...,...,...,...,...,...,...
4495793,https://www.uniqlo.com/eu/en/product/low-rise-...,Low Rise Maternity Briefs,https://image.uniqlo.com/UQ/ST3/WesternCommon/...,uniqlo,ab75ec759ff1f4234a4175b385dc7690,ab/75/ab75ec759ff1f4234a4175b385dc7690.jpg
4495794,https://www.uniqlo.com/eu/en/product/ultra-str...,Ultra Stretch Maternity Trousers,https://image.uniqlo.com/UQ/ST3/WesternCommon/...,uniqlo,6a3b325dc5afa2a2559ee459da08f1db,6a/3b/6a3b325dc5afa2a2559ee459da08f1db.jpg
4495795,https://www.uniqlo.com/eu/en/product/ultra-str...,Ultra Stretch Maternity Denim Leggings,https://image.uniqlo.com/UQ/ST3/WesternCommon/...,uniqlo,c12ce4c9bb58b1090a16c133c6b5dcc9,c1/2c/c12ce4c9bb58b1090a16c133c6b5dcc9.jpg
4495796,https://www.uniqlo.com/eu/en/product/high-rise...,High Rise Maternity Briefs,https://image.uniqlo.com/UQ/ST3/WesternCommon/...,uniqlo,a81ebd5d2b16abd5e426f61d711aeb68,a8/1e/a81ebd5d2b16abd5e426f61d711aeb68.jpg


In [None]:
# URL substrings that mark an image as class 1 for each shop, found by analysing the
# image links. A URL containing any of its shop's markers is labelled 1, otherwise 0.
LABEL_1_URL_MARKERS = {
    "bershka": ["2_4_0", "2_13_0"],
    "finn_flare": ["_100.jpg"],
    "guess": ["/GHOST/"],
    "marks_&_spencer": ["_90?"],
    "roxy": [",f_", ",v_"],
    "zara": ["_6_1_1.", "_6_2_1.", "_6_22_1."],
}


def _has_any_marker(url, markers):
    '''1 if `url` contains any of `markers` (plain substring match, no regex), else 0.'''
    return int(any(marker in url for marker in markers))


df_f = df[df["image_hash"].isin(valid_files)].reset_index(drop=True).copy()

# Apply each rule only to its own shop's rows, instead of to the whole column.
for shop, markers in LABEL_1_URL_MARKERS.items():
    in_shop = df_f["name_of_shop"] == shop
    df_f.loc[in_shop, "label"] = df_f.loc[in_shop, "image_original"].apply(_has_any_marker, args=(markers,))

# Any row left unlabelled comes from a shop with no rule above.
unlabelled_shops = sorted(df_f.loc[df_f["label"].isna(), "name_of_shop"].unique())
assert not unlabelled_shops, f"no labeling rule for shop(s): {unlabelled_shops}"

df_f["label"] = df_f["label"].astype(int)
df_f

Unnamed: 0,product_link,product_name,image_original,name_of_shop,image_hash,path,label
0,https://www.finn-flare.ru/catalog/zhenskie-puh...,Куртка женская,https://cdn.finnflare.com/upload/resize_cache/...,finn_flare,b4272bf2e57f5204f63b6beb82ab1f5f,b4/27/b4272bf2e57f5204f63b6beb82ab1f5f.jpg,0
1,https://www.finn-flare.ru/catalog/zhenskie-puh...,Куртка женская,https://cdn.finnflare.com/upload/resize_cache...,finn_flare,430a9208cbfc7838188f64dbc54a55a3,43/0a/430a9208cbfc7838188f64dbc54a55a3.jpg,0
2,https://www.finn-flare.ru/catalog/zhenskie-puh...,Куртка женская,https://cdn.finnflare.com/upload/resize_cache...,finn_flare,938419b527176e86f7bb14699bf59108,93/84/938419b527176e86f7bb14699bf59108.jpg,0
3,https://www.finn-flare.ru/catalog/zhenskie-puh...,Куртка женская,https://cdn.finnflare.com/upload/resize_cache...,finn_flare,2e8cee0e8f70251b82c44b44225cd2d3,2e/8c/2e8cee0e8f70251b82c44b44225cd2d3.jpg,0
4,https://www.finn-flare.ru/catalog/zhenskie-puh...,Куртка женская,https://cdn.finnflare.com/upload/resize_cache...,finn_flare,14fd203b49a2fee4b3f1e45d78a73a2f,14/fd/14fd203b49a2fee4b3f1e45d78a73a2f.jpg,0
...,...,...,...,...,...,...,...
123442,https://www.bershka.com/rs/en/women/collaborat...,Kuromi mobile phone case,https://static.bershka.net/4/photos2/2023/V/0/...,bershka,38b1ab927b955c72a6c6fcd92bf2625d,38/b1/38b1ab927b955c72a6c6fcd92bf2625d.jpg,0
123443,https://www.bershka.com/rs/en/women/collaborat...,Kuromi mobile phone case,https://static.bershka.net/4/photos2/2023/V/0...,bershka,807ba1626c3dcfd32308e1d0a7e33bee,80/7b/807ba1626c3dcfd32308e1d0a7e33bee.jpg,0
123444,https://www.bershka.com/rs/en/women/collaborat...,Kuromi mobile phone case,https://static.bershka.net/4/photos2/2023/V/0...,bershka,a73667bf12d5ca627f91b390c1594da9,a7/36/a73667bf12d5ca627f91b390c1594da9.jpg,0
123445,https://www.bershka.com/rs/en/women/collaborat...,Kuromi mobile phone case,https://static.bershka.net/4/photos2/2023/V/0...,bershka,47460493d109a9cf4294d629416c938c,47/46/47460493d109a9cf4294d629416c938c.jpg,1


In [None]:
# Manual remapping for outliers: these Zara images were judged to be class 1 on manual
# review, but their URLs do not carry any of the Zara URL markers used by the automatic
# labeling rule, so they are forced to label 1 here.
manual = [
    "https://zara-ru.com/images/detailed/951/1564341800_6_5_1.jpg",
    "https://zara-ru.com/images/detailed/1034/3427711800_15_1_1.jpg",
    "https://zara-ru.com/images/detailed/1035/1255808800_6_3_1.jpg",
    "https://zara-ru.com/images/detailed/903/3548244700_6_5_1.jpg",
    "https://zara-ru.com/images/detailed/903/3548244700_6_6_1.jpg",
    "https://zara-ru.com/images/detailed/904/1255808800_6_3_1.jpg",
    "https://zara-ru.com/images/detailed/903/8073228031_6_5_1.jpg",
    "https://zara-ru.com/images/detailed/903/8073228031_6_6_1.jpg",
]
assert len(manual) == len(set(manual)), "duplicate URL in the manual remapping list"

df_f.loc[df_f["image_original"].isin(manual), "label"] = 1

In [None]:
# Class balance after automatic and manual labeling (NaN would show up as its own row).
df_f["label"].value_counts(dropna=False)

0    102869
1     20578
Name: label, dtype: int64

In [None]:
class SquarePadTensor:
    '''
    Pad an image tensor to a square, keeping the original content centred.

    Meant to run after `transforms.ToTensor()`. Only the last two dimensions
    (height, width) are padded, so leading channel/batch dimensions are allowed.
    When the size difference is odd, the extra pixel goes to the right/bottom.

    Args:
        fill (number): value for padded pixels, defaults to 0 (black for an
            un-normalized [0, 1] image)
        padding_mode (str): torchvision padding mode, defaults to 'constant'
    '''

    def __init__(self, fill=0, padding_mode='constant'):
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, image):
        h, w = image.shape[-2:]
        side = max(w, h)
        left = (side - w) // 2
        top = (side - h) // 2
        right = side - w - left
        bottom = side - h - top

        # torchvision's 4-element padding order is (left, top, right, bottom)
        padding = (left, top, right, bottom)
        return transforms.functional.pad(image, padding, self.fill, self.padding_mode)

In [None]:
class ImageDataset(Dataset):
    '''
    Map-style dataset of (image, label) pairs, with images read from disk by OpenCV.

    Args:
        dataframe (pd.DataFrame): needs a "path" column (image path relative to
            `folder`) and a "label" column
        folder (str or Path): root folder that the relative paths are resolved against
        transform (callable, optional): applied to the RGB numpy image returned by OpenCV
    '''

    def __init__(self, dataframe, folder, transform=None):
        root = Path(folder)
        self.images = [root / rel_path for rel_path in dataframe["path"]]
        self.img_labels = dataframe["label"].values

        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = cv2.imread(str(image_path))
        # cv2.imread returns None (it does not raise) for a missing or unreadable file;
        # without this check the failure surfaces as an opaque cv2.cvtColor error.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads BGR; the ImageNet normalization downstream expects RGB order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        label = self.img_labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Standard ImageNet channel statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _square_resize_normalize():
    '''Shared tail: pad to a square so resizing keeps aspect ratio, resize to 224, normalize.'''
    return [
        SquarePadTensor(),
        transforms.Resize(224, antialias=True),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]


# Training adds a random horizontal flip; validation is deterministic.
train_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.RandomHorizontalFlip(p=0.3)] + _square_resize_normalize()
)
val_transforms = transforms.Compose([transforms.ToTensor()] + _square_resize_normalize())

In [None]:
# Stratified 85/15 split, so train and validation keep the same class ratio.
train_images_full, val_images = train_test_split(
    df_f,
    test_size=0.15,
    stratify=df_f["label"],
    random_state=42,
)

In [None]:
batch_size = 128

# Training uses the augmenting transforms and shuffles; validation keeps a fixed order.
train_dataset, val_dataset = (
    ImageDataset(train_images_full, folder="images", transform=train_transforms),
    ImageDataset(val_images, folder="images", transform=val_transforms),
)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def _evaluate_epoch(model, criterion, loader):
    '''Run `model` over `loader` without gradients; return (mean batch loss, labels, predictions) on CPU.'''
    model.eval()
    label_batches, prediction_batches = [], []
    total_loss = 0.0
    with torch.no_grad():
        for data, label in tqdm(loader):
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            total_loss += criterion(output, label).item()

            label_batches.append(label.cpu())
            prediction_batches.append(torch.argmax(output, dim=1).cpu())

    mean_loss = total_loss / max(len(loader), 1)
    if not label_batches:
        empty = torch.empty(0, dtype=torch.long)
        return mean_loss, empty, empty
    # concatenate once: torch.cat inside the loop would be quadratic in dataset size
    return mean_loss, torch.cat(label_batches), torch.cat(prediction_batches)


def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=10, last_epoch=0, stage="warmup", checkpoint_dir="checkpoints_model"):
    '''
    Train `model` for epochs in range(last_epoch, num_epochs), checkpointing and validating after each.

    After every epoch the state dict is saved to
    `<checkpoint_dir>/model_checkpoint_{stage}_epoch_{epoch}.pt` (1-based epoch number).
    Then the validation loss, accuracy and binary precision/recall/F1 (class 1 as the
    positive class) are printed. `val_support` is the number of class-1 validation
    samples. sklearn's support is None for average='binary', so it is computed here.

    Args:
        model (nn.Module): classifier returning per-class logits; expected to already be on `device`
        criterion: loss taking (logits, labels), e.g. nn.CrossEntropyLoss
        optimizer (torch.optim.Optimizer): optimizer over the parameters to train
        train_loader (DataLoader): yields (images, labels) training batches
        val_loader (DataLoader): yields (images, labels) validation batches
        num_epochs (int): exclusive upper bound of the epoch range, not an epoch count
        last_epoch (int): first epoch index to run, so a later stage continues the numbering
        stage (str): tag used in checkpoint file names
        checkpoint_dir (str or Path): folder for checkpoints, created if missing

    Returns:
        the trained `model` (the same object, updated in place)
    '''
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in tqdm(range(last_epoch, num_epochs)):
        model.train()
        epoch_loss = 0.0

        for data, label in tqdm(train_loader):
            data = data.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            loss = criterion(model(data), label)
            loss.backward()
            optimizer.step()

            # .item() takes a plain float. Accumulating the tensor itself would keep a
            # reference to every batch's autograd graph for the whole epoch.
            epoch_loss += loss.item() / len(train_loader)

        print(f"Epoch : {epoch + 1}, train loss : {epoch_loss}")
        torch.save(model.state_dict(), checkpoint_dir / f"model_checkpoint_{stage}_epoch_{epoch + 1}.pt")

        val_loss, all_labels, all_predictions = _evaluate_epoch(model, criterion, val_loader)
        accuracy = (all_predictions == all_labels).float().mean().item()
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels.numpy(), all_predictions.numpy(), average='binary'
        )
        support = int((all_labels == 1).sum())
        print(f"Epoch : {epoch + 1}, val_loss : {val_loss}")
        print(f"val_accuracy : {accuracy}, val_precision : {precision}, val_recall : {recall}, val_f1 : {f1}, val_support : {support}")
        gc.collect()
    return model

In [None]:
from torch.optim.optimizer import Optimizer


class Lion(Optimizer):
    r"""Implements the Lion optimizer (sign-of-momentum updates).

    For each parameter ``p`` with gradient ``g`` and a zero-initialised momentum
    buffer ``m`` (stored as ``state['exp_avg']``), one :meth:`step` performs::

        p = p * (1 - lr * weight_decay)                 # decoupled weight decay
        p = p - lr * sign(beta1 * m + (1 - beta1) * g)  # update direction
        m = beta2 * m + (1 - beta2) * g                 # momentum update

    Only one state tensor is kept per parameter. No second-moment estimate is
    stored, unlike Adam.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        """Validate and store the hyperparameters.

        Args:
            params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups
            lr (float, optional): learning rate (default: 1e-4)
            betas (Tuple[float, float], optional): ``beta1`` interpolates the momentum
                and the current gradient to form the update direction; ``beta2`` is
                the decay rate of the momentum buffer (default: (0.9, 0.99))
            weight_decay (float, optional): decoupled weight decay coefficient (default: 0)

        Raises:
            ValueError: if ``lr`` is negative or either beta is outside [0, 1).
        """
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): a closure that re-evaluates the model
                and returns the loss; it is run with gradients enabled.

        Returns:
            the loss returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # Decoupled weight decay, applied directly to the weights before the update.
                p.data.mul_(1 - lr * group['weight_decay'])

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                # The step uses only the sign, so every coordinate moves by exactly lr.
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(torch.sign(update), alpha=-lr)
                # Momentum is refreshed after the step, using beta2 rather than beta1.
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss

In [None]:
# EfficientNet-B0 with torchvision's default pretrained weights.
# The trailing `;` keeps the module repr out of the cell output.
model = torchvision.models.efficientnet_b0(weights="DEFAULT")
model.train();

In [None]:
# Warm-up stage: freeze every pretrained weight, then attach a fresh 2-class head
# (created after freezing, so it is the only trainable part).
for param in model.parameters():
    param.requires_grad = False

# Read the input width from the existing head instead of hard-coding 1280, so a
# different backbone variant does not silently break the shapes.
model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=2)

model.to(device);

In [None]:
# Idempotent replacement for `!mkdir`, which errors when the folder already exists.
Path("checkpoints_model").mkdir(exist_ok=True)

mkdir: cannot create directory ‘checkpoints_model’: File exists


In [None]:
criterion = nn.CrossEntropyLoss()

# During warm-up only the still-trainable parameters (the new head) go to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Lion(params=trainable_params)

model = train_model(
    model, criterion, optimizer,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    num_epochs=5,
    stage="warmup",
)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/820 [00:00<?, ?it/s]

Epoch : 1, train loss : 0.17443671822547913


  0%|          | 0/145 [00:00<?, ?it/s]

Epoch : 1, val_loss : 0.08454915136098862
val_accuracy : 0.9755913615226746, val_precision : 0.9082739386427022, val_recall : 0.9494655004859086, val_f1 : 0.9284130503642698, val_support : None


  0%|          | 0/820 [00:00<?, ?it/s]

Epoch : 2, train loss : 0.08840615302324295


  0%|          | 0/145 [00:00<?, ?it/s]

Epoch : 2, val_loss : 0.07879272103309631
val_accuracy : 0.9787774085998535, val_precision : 0.9188432835820896, val_recall : 0.9572400388726919, val_f1 : 0.9376487386958592, val_support : None


  0%|          | 0/820 [00:00<?, ?it/s]

Epoch : 3, train loss : 0.08171792328357697


  0%|          | 0/145 [00:00<?, ?it/s]

Epoch : 3, val_loss : 0.08494660258293152
val_accuracy : 0.9805594682693481, val_precision : 0.9183798711261123, val_recall : 0.9695497246517655, val_f1 : 0.9432713520327766, val_support : None


  0%|          | 0/820 [00:00<?, ?it/s]

Epoch : 4, train loss : 0.07909845560789108


  0%|          | 0/145 [00:00<?, ?it/s]

Epoch : 4, val_loss : 0.0828813686966896
val_accuracy : 0.9808834791183472, val_precision : 0.9216291268127121, val_recall : 0.9676060900550697, val_f1 : 0.9440581542351455, val_support : None


  0%|          | 0/820 [00:00<?, ?it/s]

Epoch : 5, train loss : 0.07822863012552261


  0%|          | 0/145 [00:00<?, ?it/s]

Epoch : 5, val_loss : 0.07448503375053406
val_accuracy : 0.9806674718856812, val_precision : 0.9276089000313381, val_recall : 0.9588597343699384, val_f1 : 0.9429754698948709, val_support : None


In [None]:
# Collect unreachable Python objects first, so CUDA tensors they hold are released,
# then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

19

In [None]:
# Fine-tuning stage: make every parameter trainable again.
# To restart from the saved warm-up weights instead of the in-memory model, load
# "checkpoints_model/model_checkpoint_warmup_epoch_5.pt" into `model` first.
model.requires_grad_(True);

In [None]:
criterion = nn.CrossEntropyLoss()
# Fresh optimizer over the whole, now unfrozen, network.
optimizer = Lion(params=model.parameters())

# Epoch numbering continues after the 5 warm-up epochs.
model = train_model(
    model, criterion, optimizer,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    num_epochs=20,
    last_epoch=5,
    stage="train",
)

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/820 [00:00<?, ?it/s]

Epoch : 6, train loss : 0.044079553335905075


  0%|          | 0/145 [00:00<?, ?it/s]

Epoch : 6, val_loss : 0.031130608171224594
val_accuracy : 0.9913597702980042, val_precision : 0.9641611163970821, val_recall : 0.9847748623258827, val_f1 : 0.9743589743589743, val_support : None


  0%|          | 0/820 [00:00<?, ?it/s]

Epoch : 7, train loss : 0.03164661303162575


  0%|          | 0/145 [00:00<?, ?it/s]

Epoch : 7, val_loss : 0.027341092005372047
val_accuracy : 0.9923318028450012, val_precision : 0.975767366720517, val_recall : 0.9782960803368966, val_f1 : 0.9770300873503721, val_support : None


  0%|          | 0/820 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

# Test

In [None]:
# Epoch 7 is the last epoch that completed before training was interrupted.
# map_location lets a checkpoint saved on the GPU be restored on a CPU-only runtime.
model.load_state_dict(torch.load("checkpoints_model/model_checkpoint_train_epoch_7.pt", map_location=device))

<All keys matched successfully>

In [None]:
# Testing uses a different store's data (H&M) that was not used during training.

df_orig = pd.read_csv("products.csv")

# Hashes of the H&M images that were actually unpacked (test_images/<hh>/<hh>/<hash>.jpg).
valid_files_test = [
    leaf.stem
    for level_1 in Path("test_images").iterdir()
    for level_2 in level_1.iterdir()
    for leaf in level_2.iterdir()
]

df_test = df_orig[df_orig["name_of_shop"] == "h&m"].reset_index(drop=True).copy()
df_test["path"] = df_test["image_hash"].apply(lambda h: f"{h[:2]}/{h[2:4]}/{h}.jpg")

df_test = df_test[df_test["image_hash"].isin(valid_files_test)].reset_index(drop=True).copy()
# H&M class-1 images are identified by this URL marker (presumably product-only shots).
df_test["label"] = df_test["image_original"].map(lambda url: int("DESCRIPTIVESTILLLIFE" in url))
df_test

In [None]:
def predict(model, test_dataloader):
    '''
    Run `model` over `test_dataloader` and compare argmax predictions with the labels.

    Args:
        model (nn.Module): classifier; assumed to return logits of shape (batch, n_classes)
        test_dataloader (DataLoader): yields (images, labels) batches

    Returns:
        accuracy (float): fraction of correct predictions (nan for an empty loader)
        all_labels (torch.Tensor): float32 CPU tensor of true labels, in loader order
        all_predictions (torch.Tensor): float32 CPU tensor of predicted class indices
    '''
    model.eval()
    label_batches, prediction_batches = [], []
    with torch.no_grad():
        for data, label in tqdm(test_dataloader):
            output = model(data.to(device))
            # Collect per-batch results on the CPU and concatenate once at the end:
            # calling torch.cat inside the loop is quadratic, and running
            # gc.collect()/empty_cache() on every batch only slows inference down.
            label_batches.append(label.cpu())
            prediction_batches.append(torch.argmax(output, dim=1).cpu())

    if not label_batches:
        empty = torch.empty(0)
        return float("nan"), empty, empty

    # float32 keeps the dtype that the previous torch.Tensor() accumulators returned
    all_labels = torch.cat(label_batches).float()
    all_predictions = torch.cat(prediction_batches).float()
    accuracy = (all_predictions == all_labels).float().mean().item()
    return accuracy, all_labels, all_predictions

In [None]:
# The held-out H&M frame is `df_test`, and its images were unpacked into "test_images",
# not the training "images" folder.
test_dataset = ImageDataset(df_test, folder="test_images", transform=val_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
accuracy, labels, predictions = predict(model, test_dataloader)

# The training labels are roughly 5:1 class 0 to class 1, so accuracy alone can hide
# poor performance on class 1; report per-class-1 precision/recall/F1 as well.
precision, recall, f1, _ = precision_recall_fscore_support(labels.numpy(), predictions.numpy(), average='binary')

print("Accuracy:", accuracy)
print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")

  0%|          | 0/166 [00:00<?, ?it/s]

Accuracy: 0.9996228218078613
