In [49]:
import os
from typing import final
from pathlib import Path
from PIL import Image
import pandas as pd
import timm

import torchvision
from torchvision import transforms

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

SEED = 42

# Seed torch's default generator so everything that draws from it (e.g. the freshly
# initialised classifier head) is reproducible under Restart & Run All.
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"

device

'cuda'

In [41]:
@final
class CustomDataset(Dataset):
    """Image dataset read from a flat directory, optionally labelled from file names.

    Labelled files are expected to be named ``<class>_<anything>``, where ``<class>``
    appears in the ``name`` column of the class CSV; the matching ``label`` column
    gives the integer target.

    Args:
        class_to_label_csv: CSV file with ``name`` and ``label`` columns.
        images_dir: Directory holding the image files (sub-directories are ignored).
        has_labels: If True, items are ``(image, label)``; otherwise
            ``(image, file_name)``, e.g. for an unlabelled test split.
        transform: Callable applied to each RGB PIL image. Defaults to
            ``transforms.ToTensor()``.
    """

    def __init__(self, class_to_label_csv: Path, images_dir: Path, has_labels=True, transform=None):
        self.class_to_label = pd.read_csv(class_to_label_csv)
        self.images_dir = images_dir
        # os.listdir/os.scandir order is arbitrary; sort so indices (and therefore a
        # seeded random_split) are stable across machines. Directories and hidden
        # files (e.g. .DS_Store) are not images and would crash __getitem__.
        self.images = sorted(
            entry.name
            for entry in os.scandir(images_dir)
            if entry.is_file() and not entry.name.startswith(".")
        )
        # `is not None` so a falsy-but-valid transform (e.g. an empty nn.Sequential)
        # is still honoured.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.has_labels = has_labels

        # Both lookups read the same named columns so they cannot disagree.
        self.classes = sorted(self.class_to_label["name"].unique())
        self.class_to_idx = dict(zip(self.class_to_label["name"], self.class_to_label["label"]))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` or, when unlabelled, ``(image, file_name)``.

        Raises:
            KeyError: if a labelled file's class prefix is not in the class CSV.
        """
        file_name = self.images[idx]
        img_path = os.path.join(self.images_dir, file_name)
        # The context manager closes the file handle; convert() loads the pixels and
        # returns an independent image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        if not self.has_labels:
            return image, file_name

        class_name = file_name.split("_")[0]
        try:
            label = self.class_to_idx[class_name]
        except KeyError:
            raise KeyError(
                f"{file_name!r}: class {class_name!r} not found in class CSV"
            ) from None
        return image, label

In [50]:
BATCH_SIZE = 32
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42
DATA_DIR = Path("./animal")
CLASS_CSV = DATA_DIR / "class_names.csv"

# Preprocessing for the pretrained backbone. The values below are taken from timm's
# pretrained config for "tf_efficientnetv2_s.in21k" (train size 300, mean = std = 0.5).
# NOTE(review): confirm with `timm.data.resolve_model_data_config(model)`.
IMAGE_SIZE = 300
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

image_transform = transforms.Compose([
    # A fixed output size is required: the default DataLoader collate_fn stacks the
    # images of a batch into one tensor and fails on mixed sizes.
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

full_train = CustomDataset(
    class_to_label_csv=CLASS_CSV,
    images_dir=DATA_DIR / "train",
    has_labels=True,
    transform=image_transform,
)

train_size = int(TRAIN_FRACTION * len(full_train))
val_size = len(full_train) - train_size

# Seeded generator -> the same train/val split on every run.
train_dataset, val_dataset = random_split(
    full_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Unlabelled split: items are (image, file_name); same preprocessing as training.
test = CustomDataset(
    class_to_label_csv=CLASS_CSV,
    images_dir=DATA_DIR / "test_new",
    has_labels=False,
    transform=image_transform,
)

test_dataloader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

In [43]:
# Derive the head width from the class CSV instead of hard-coding it, so the model
# and the labels cannot drift apart.
label_values = set(full_train.class_to_idx.values())
NUM_CLASSES = len(label_values)
# CrossEntropyLoss needs every target in [0, NUM_CLASSES); fail here, not mid-training.
assert all(0 <= label < NUM_CLASSES for label in label_values), (
    f"labels in class_names.csv must be 0..{NUM_CLASSES - 1}, got {sorted(label_values)}"
)

model = timm.create_model(
    "tf_efficientnetv2_s.in21k",  # EfficientNetV2-S with ImageNet-21k pretrained weights
    pretrained=True,
    num_classes=NUM_CLASSES,  # head is rebuilt with NUM_CLASSES outputs
)

In [44]:
# Show the classification head that was sized by `num_classes` at creation time.
model.get_classifier()

Linear(in_features=1280, out_features=14, bias=True)

In [45]:
# Linear-probe setup: freeze every parameter, then re-enable gradients on the
# classifier head only, so training updates just the final layer.
# Trailing `;` keeps the cell silent (requires_grad_ returns the module).
model.requires_grad_(False)
model.classifier.requires_grad_(True);

In [46]:
LEARNING_RATE = 1e-3

loss_fn = torch.nn.CrossEntropyLoss()

# Give the optimizer only the parameters that are still trainable (the classifier
# head after the freezing cell). Frozen backbone weights never receive gradients, so
# listing them only obscures what is actually being optimised.
# NOTE: if layers are unfrozen later, rebuild the optimizer so they are included.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optim = torch.optim.Adam(params=trainable_params, lr=LEARNING_RATE)

epochs = 10

In [None]:
# Train with the project helper from misc.py; at this point only the classifier
# head has requires_grad=True.
from misc import train_model

# NOTE(review): the structure of `res` is defined by misc.train_model —
# presumably per-epoch metrics; confirm before relying on it downstream.
res = train_model(model, train_dataloader, val_dataloader, loss_fn, optim, device, epochs)

  0%|          | 0/10 [00:00<?, ?it/s]