In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
from torchvision import transforms
from torch import nn
import torch.optim as optim

from tqdm import tqdm
import gdown
import timm  # Import timm for Vision Transformer models
from torchvision.transforms import InterpolationMode
from torchvision.transforms import InterpolationMode
import pytorch_lightning as pl
import torch.nn.functional as F  # Import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.optim import lr_scheduler  # Import lr_scheduler correctly
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

device = "cuda:1"




* 'schema_extra' has been renamed to 'json_schema_extra'


In [2]:
# Device configuration
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = "cuda:1"
print(device)

# Create data directory
os.makedirs("data", exist_ok=True)

# Download dataset using gdown
url = 'https://drive.google.com/uc?id=1IEnpbGjNqXYF4vPY1NW-ODcrYZomyb4S'
output = 'data/dataset.zip'
gdown.download(url, output, quiet=False)

# Unzip dataset
!unzip -q data/dataset.zip -d data

cuda:1


Downloading...
From (original): https://drive.google.com/uc?id=1IEnpbGjNqXYF4vPY1NW-ODcrYZomyb4S
From (redirected): https://drive.google.com/uc?id=1IEnpbGjNqXYF4vPY1NW-ODcrYZomyb4S&confirm=t&uuid=cffc64e7-1736-4b68-b64a-73d6b3fe384b
To: /home/atellezfernandez/git/Juridique/COMPVIZ/data/dataset.zip
 27%|██▋       | 82.3M/310M [00:00<00:01, 116MB/s] 

KeyboardInterrupt: 

A class BirdDataset is defined in order to load all the data from the storage directory (path), the path to the class mapping file (class_mapping_path) and a few transforms to be applied to the images (transforms).

In [3]:
# Define BirdDataset
class BirdDataset(Dataset):
    def __init__(self, path, class_mapping_path, transforms=None):
        super().__init__()
        self.path = path
        self.image_path = list(Path(path).glob("*/*.jpg"))
        self.class_mapping = pd.read_csv(class_mapping_path)
        self.classes = self.class_mapping.sort_values(by='idx').category_cub.to_list()
        self.transforms = transforms

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        image_path = self.image_path[idx]
        target = torch.tensor(self.classes.index(image_path.parent.name))
        with open(image_path, 'rb') as f:
            with Image.open(f) as img:
                image = img.convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, target

Data augmentation is carried out on the training data and the validation data in order to diversify our data and make it more general and therefore more robust. This will help to reduce overfitting later on. More specifically, functions that will allow us to resize, randomly crop (randomresizedcrop) or perform random horizontal inversion (RandomHorizontalFlip).  ----- 

In [4]:
#EASY TRANSFORMATION!! MOST PERFORMANT ONE
train_transforms = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),  # Redimensionner le plus petit côté à 256 pixels
    transforms.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),  # Recadrer aléatoirement à 224x224 pixels
    transforms.RandomHorizontalFlip(),
    #transforms.RandomRotation(30),  # Augmenter la rotation aléatoire
    #transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    #transforms.RandomGrayscale(p=0.2),  # Ajouter du RandomGrayscale pour augmenter la variété des données
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

#The descomented transformations are the ones that are a bit more complex, they are worst

# Data augmentations and transformations for validation
val_transforms = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


I propose AugMix, a data processing technique that mixes augmented images and enforces consistent embeddings of the augmented images, which results in increased robustness and improved uncertainty calibration. AugMix does not require tuning to work correctly, as with random cropping or CutOut, and thus enables plug-and-play data augmentation. AugMix significantly improves robustness and uncertainty measures on challenging image classification benchmarks, closing the gap between previous methods and the best possible performance by more than half in some cases.

- CLASS AugMixTransform :  custom Albumentations transformation that wraps the augment_and_mix function to integrate it into the transformation pipeline

- class AlbumentationsTransformWrapper : wrapper to make Albumentations transformations compatible with torchvision transformations.

- data_transforms_train_heavy: REALLY HEAVY TRANSFORMATIONS (random cropping, rotation, noise, optical distortion, etc.) + AugMix

In [5]:
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
import random
import albumentations as A

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from torch.utils.data import DataLoader
import timm
import torch.nn as nn
import torch

def augmentations_list():
    return [
        A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4, p=1.0),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0),
        A.Solarize(p=1.0),
        A.Posterize(num_bits=4, p=1.0),
        A.Equalize(p=1.0),
        A.InvertImg(p=1.0),
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1.0),
        A.Emboss(alpha=(0.2, 0.5), strength=(0.5, 1.0), p=1.0)
    ]

def normalize(image, mean, std):
    """Normalize input image channel-wise to zero mean and unit variance."""
    image = image.transpose(2, 0, 1).astype(np.float32)  # Switch to channel-first and convert to float32
    mean, std = np.array(mean), np.array(std)
    image = (image - mean[:, None, None]) / std[:, None, None]
    return image.transpose(1, 2, 0).astype(np.float32)

def apply_op(image, op):
    image = np.clip(image * 255., 0, 255).astype(np.uint8)
    pil_img = Image.fromarray(image)  # Convert to PIL.Image
    pil_img = op(image=np.array(pil_img))['image']
    return np.asarray(pil_img) / 255.

def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1., mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Perform AugMix augmentations and compute mixture.

    Args:
        image: Raw input image as float32 np.ndarray of shape (h, w, c)
        severity: Severity of underlying augmentation operators (between 1 to 10).
        width: Width of augmentation chain
        depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
            from [1, 3]
        alpha: Probability coefficient for Beta and Dirichlet distributions.

    Returns:
        mixed: Augmented and mixed image.
    """
    ws = np.float32(
        np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))

    mix = np.zeros_like(image, dtype=np.float32)
    ops = augmentations_list()
    for i in range(width):
        image_aug = image.copy().astype(np.float32)
        d = depth if depth > 0 else np.random.randint(1, 4)
        for _ in range(d):
            op = random.choice(ops)
            image_aug = apply_op(image_aug, op)
        # Preprocessing commutes since all coefficients are convex
        mix += ws[i] * normalize(image_aug, mean, std)

    mixed = (1 - m) * normalize(image, mean, std) + m * mix
    return mixed

# Custom Albumentations transformation for AugMix
class AugMixTransform(A.ImageOnlyTransform):
    def __init__(self, always_apply=False, p=1.0):
        super(AugMixTransform, self).__init__(always_apply, p)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def apply(self, img, **params):
        return augment_and_mix(img, mean=self.mean, std=self.std)

# Wrapper class to make Albumentations compatible with torchvision transforms
class AlbumentationsTransformWrapper:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        img = np.array(img)
        transformed = self.transform(image=img)
        return transformed["image"]

# Define your heavy data transformations including Albumentations and AugMix
data_transforms_train_heavy = AlbumentationsTransformWrapper(A.Compose([
    A.RandomResizedCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.OneOf([
        A.GaussNoise(),
        A.ISONoise(),
    ], p=0.2),
    A.OneOf([
        A.MotionBlur(p=0.2),
        A.MedianBlur(blur_limit=3, p=0.1),
        A.Blur(blur_limit=3, p=0.1),
    ], p=0.2),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.2),
    A.OneOf([
        A.OpticalDistortion(p=0.3),
        A.GridDistortion(p=0.1),
        A.PiecewiseAffine(p=0.3),
    ], p=0.2),
    A.OneOf([
        A.Sharpen(p=1.0),
        A.Emboss(p=1.0),
        A.RandomBrightnessContrast(p=1.0),            
    ], p=0.3),
    A.OneOf([
        A.HueSaturationValue(hue_shift_limit=0, sat_shift_limit=10, val_shift_limit=10, p=0.1),
        A.ColorJitter(brightness=0.3, contrast=0.1, saturation=0.1, hue=0.1, p=0.1),
    ]),
    AugMixTransform(p=1.0),
    A.Resize(224, 224, interpolation=cv2.INTER_LANCZOS4),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
]))


- On décide de prendre un batch_size de 32 après avoir essayer de prendre d'autres tailles de batch tel que 16 et 64.
- On a divisé notre dataset en train set et en validation set
- On a décidé d'utiliser les dataloader car ils permettent de diviser automatiquement le jeu de données en lots (batches) de taille spécifiée. Cela simplifie le code et évite d'avoir à implémenter manuellement cette logique. En plus, ils nous permettent de mélanger les données pour réduire l'overfitting lors de l'entrainement du modèle. Ils permettent aussi un chargemnt efficace des données.

(This was chosen on the first model we did)

In [6]:
batch_size = 32

train_dataset = BirdDataset("./data/dataset/cropped_train_images/", "./data/dataset/class_indexes.csv", transforms=data_transforms_train_heavy)
val_dataset = BirdDataset("./data/dataset/cropped_val_images/", "./data/dataset/class_indexes.csv", transforms=val_transforms)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
import pytorch_lightning as pl
import torch.nn.functional as F  # Import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


nclasses = len(train_dataset.classes)


After testing various models such as ConvNext, ResNet30, ResNet50, EfficientNet, vit_base, and VAE, we decided to retain the vit_large_patch16_224 model due to its superior accuracy. To implement this, the VitNet class was created using PyTorch Lightning. This class inherits from pl.LightningModule, facilitating clear separation of training, validation, and prediction code, thereby enhancing readability and organization.

### Class: VitNet
- **Inherits from:** pl.LightningModule
- **Base model:** Pre-trained Vision Transformer (ViT)

### Functions:

1. **forward(x)**
   - **Purpose:** Defines the data flow through the model.
   - **Process:**
     1. Input data 'x' passes through the ViT model.
     2. Non-linearity applied using ReLU.
     3. Data passes through a linear layer to obtain final predictions.

2. **training_step(batch, batch_idx)**
   - **Purpose:** Calculates the loss during training.
   - **Process:**
     1. Compute loss using cross-entropy on predictions.
     2. Store intermediate results (loss, correct predictions, total targets) in train_outputs list.
   - **Reason:** Cross-entropy is suitable for multi-class classification.

3. **on_train_epoch_end()**
   - **Purpose:** Calculates average training loss and accuracy at the end of each epoch.
   - **Process:**
     1. Calculate average loss from train_outputs.
     2. Compute training accuracy.
     3. Clear train_outputs list for the next epoch.

4. **validation_step(batch, batch_idx)**
   - **Purpose:** Validates the model.
   - **Process:** Similar to training_step, but stores results in val_outputs.

5. **on_validation_epoch_end()**
   - **Purpose:** Calculates average validation loss and accuracy at the end of each epoch.
   - **Process:** Similar to on_train_epoch_end, but uses validation results.

6. **configure_optimizers()**
   - **Purpose:** Configures the optimizer and learning rate scheduler.
   - **Process:**
     1. Use SGD optimizer (preferred over AdamW for better results).
     2. Apply CosineAnnealingLR to update the learning rate.

In [8]:
class VitNet(pl.LightningModule):
    def __init__(self, lr, momentum):
        super(VitNet, self).__init__()
        self.lr = lr
        self.momentum = momentum
        self.vit = timm.create_model('vit_large_patch16_224', pretrained=True)
        num_ftrs = self.vit.head.in_features
        self.vit.head = nn.Identity()
        self.fc = nn.Linear(num_ftrs, nclasses)
        self.train_outputs = []
        self.val_outputs = []

    def forward(self, x):
        x = F.relu(self.vit(x))
        x = self.fc(x)
        return x

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        pred = output.data.max(1, keepdim=True)[1]
        n_correct_pred = pred.eq(target.data.view_as(pred)).sum().detach()
        self.train_outputs.append({'loss': loss, 'n_correct_pred': n_correct_pred, 'n_pred': len(target)})
        self.log('train_loss', loss)  # log the training loss
        return loss

    def on_train_epoch_end(self):
        avg_loss = torch.stack([x['loss'] for x in self.train_outputs]).mean()
        train_acc = sum([x['n_correct_pred'] for x in self.train_outputs]) / sum(x['n_pred'] for x in self.train_outputs)
        print(f"Epoch {self.current_epoch}: Train Loss: {avg_loss:.4f}, Train Accuracy: {train_acc:.4f}")
        self.train_outputs.clear()

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        pred = output.data.max(1, keepdim=True)[1]
        n_correct_pred = pred.eq(target.data.view_as(pred)).sum().detach()
        self.val_outputs.append({'val_loss': loss, 'n_correct_pred': n_correct_pred, 'n_pred': len(target)})
        self.log('val_loss', loss)  # log the validation loss
        return loss

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x['val_loss'] for x in self.val_outputs]).mean()
        val_acc = sum([x['n_correct_pred'] for x in self.val_outputs]) / sum(x['n_pred'] for x in self.val_outputs)
        print(f"Epoch {self.current_epoch}: Val Loss: {avg_loss:.4f}, Val Accuracy: {val_acc:.4f}")
        self.val_outputs.clear()

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
        return [optimizer], [scheduler]

### Callbacks in PyTorch Lightning
- **Purpose:** Add custom behaviors during training.

#### Model Checkpoint
- **Why:** Save the model and its hyperparameters when it achieves the lowest validation loss.
- **Benefits:**
  - Captures the best state of the model before overfitting.
  - Allows comparison of different model versions to assess changes in hyperparameters or structure.

#### Early Stopping
- **Criteria:** Stops training after 5 epochs without improvement (no reduction in validation loss).

### Model Training and Validation
- **Tool:** PyTorch Lightning Trainer
- **Training Duration:** Maximum of 50 epochs

This schematic approach highlights the use of callbacks and the trainer to optimize the training and validation process effectively.

In [10]:
# Training and validation
from torch.optim import lr_scheduler  # Import lr_scheduler correctly
lr = 0.01
momentum = 0.89

model = VitNet(lr, momentum)

# Callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  # Change this line
    dirpath='checkpoints',
    filename='best_checkpoint_MIX_new_new',
    save_top_k=1,
    mode='min'
)

early_stopping_callback = EarlyStopping(
    monitor='val_loss',  # Change this line
    patience=5,
    verbose=True,
    mode='min'
)


# Trainer
trainer = pl.Trainer(
    max_epochs=30,       ##################################### that is where i call the hard stopping!!!
    #i forced to be 10
    devices=1 if torch.cuda.is_available() else 0, 
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    precision=16,  # Enable Automatic Mixed Precision
    callbacks=[checkpoint_callback, early_stopping_callback],
    gradient_clip_val=0.5  # Gradient clipping
)

trainer.fit(model, train_loader, val_loader)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


DeferredCudaCallError: CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. device=, num_gpus=

CUDA call was originally invoked at:

  File "/home/atellezfernandez/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/traitlets/config/application.py", line 1053, in launch_instance
    app.start()
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/lib/python3.10/asyncio/base_events.py", line 595, in run_forever
    self._run_once()
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/lib/python3.10/asyncio/base_events.py", line 1881, in _run_once
    handle._run()
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue
    await self.process_one()
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 505, in process_one
    await dispatch(*args)
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell
    await result
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 740, in execute_request
    reply_content = await reply_content
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
    res = shell.run_cell(
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 546, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3024, in run_cell
    result = self._run_cell(
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell
    result = runner(coro)
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_938251/2519443962.py", line 2, in <module>
    import torch
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/torch/__init__.py", line 1478, in <module>
    _C._initExtension(manager_path())
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/torch/cuda/__init__.py", line 238, in <module>
    _lazy_call(_check_capability)
  File "/home/atellezfernandez/.pyenv/versions/3.10.0/envs/theraclion/lib/python3.10/site-packages/torch/cuda/__init__.py", line 235, in _lazy_call
    _queued_calls.append((callable, traceback.format_stack()))


In [None]:
torch.save(model.state_dict(), 'best_model_yolo3_MIX.pth')


In [None]:
# Test prediction
def pil_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def predict_test(model, test_dir, output_path, transform):
    model.eval()
    output_file = open(output_path, "w")
    output_file.write("ID,category\n")
    for f in tqdm(os.listdir(test_dir)):
        if 'jpg' in f:
            data = transform(pil_loader(os.path.join(test_dir, f)))
            data = data.unsqueeze(0).to(device)
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            output_file.write("%s,%d\n" % (f[:-4], pred.item()))
    output_file.close()

# Load the best model
model.load_state_dict(torch.load('best_model_yolo3_MIX_new_new.pth'))
model.to(device)

# Set the test directory and output path
test_dir = './data/dataset/test_images'
output_path = './yooyo_kaggle_vit_large_yolo3_MIX_new_new.csv'


# Make predictions on the test data
predict_test(model, test_dir, output_path, val_transforms)

100%|██████████| 620/620 [00:07<00:00, 79.91it/s]
