In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import torchvision.models as models
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.optim as optim

In [None]:

class ImbalancedClassSampler(WeightedRandomSampler):
    """Weighted sampler that counteracts class imbalance.

    Every sample is weighted by the reciprocal of the number of samples that
    share its label, so each class present in ``dataset`` carries the same
    total sampling weight.
    """

    def __init__(self,
                 dataset: Dataset,
                 length: int,
                 replacement: bool = True):
        """Compute per-sample weights from the labels found in ``dataset``.

        Args:
            dataset (Dataset): Iterable of ``(data, label)`` pairs. Only the
                label is used; it must be convertible with ``int()``.
            length (int): Number of indices drawn per pass over the sampler.
            replacement (bool, optional): Whether indices are drawn with
                replacement. Defaults to True.

        Raises:
            ValueError: If ``dataset`` yields no samples.
        """
        # int() accepts plain ints as well as single-element tensors/arrays and
        # guarantees a flat 1-D label tensor, which bincount and the parent
        # sampler both require (also for a dataset with a single sample).
        labels = torch.tensor([int(label) for _, label in dataset], dtype=torch.long)
        if labels.numel() == 0:
            raise ValueError("ImbalancedClassSampler requires a non-empty dataset")

        class_count = torch.bincount(labels)
        # Label values that never occur get an infinite weight here, but those
        # entries are never selected by the indexing below.
        class_weighting = 1. / class_count
        sample_weights = class_weighting[labels]

        super().__init__(sample_weights, length, replacement=replacement)

In [None]:
from typing import Callable, Optional

from torch.utils.data import Dataset

class Transformed(Dataset):
    """Wrap a dataset so a transform is applied to the data part of each sample."""

    def __init__(self, dataset: Dataset, transform: Optional[Callable] = None):
        """Attach a transform to a dataset that has no transform support of its own.

        Args:
            dataset (Dataset): Underlying dataset yielding ``(data, label)`` pairs.
            transform (Optional[Callable], optional): Callable applied to the data
                element of every sample; the label is passed through untouched.
                Defaults to None, in which case samples are returned as-is.
        """
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index: int):
        """Return the sample at ``index`` with the transform applied to its data."""
        sample, target = self.dataset[index]

        if self.transform is None:
            return sample, target

        return self.transform(sample), target

In [None]:
class SegmentedFood(Dataset):
    """Image-folder dataset with one sub-directory of ``.png`` files per class.

    Attributes are filled once at construction time:
        data: file paths of all kept images.
        labels: integer class index for each entry of ``data``.
        classes: class (directory) names; the position is the class index.
    """

    def __init__(self, data_dir: str, min_samples: int = 10, transform=None):
        """Index the image files below ``data_dir``.

        Args:
            data_dir (str): Directory whose immediate sub-directories are classes.
            min_samples (int, optional): A class is kept only if it has strictly
                more than this many ``.png`` files. Defaults to 10.
            transform (optional): Callable applied to each loaded image.
        """
        self.data_dir = data_dir
        self.transform = transform

        self.data = []
        self.labels = []
        self.classes = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # class-name -> index mapping stable across runs and machines, which
        # matters when saved checkpoints are reloaded later.
        class_names = sorted(
            d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
        )

        for class_name in class_names:
            class_dir = os.path.join(self.data_dir, class_name)

            files_per_class = sorted(
                os.path.join(class_dir, file_name)
                for file_name in os.listdir(class_dir)
                if file_name.endswith(".png")
            )

            # Classes with too few samples are dropped entirely; indices stay
            # contiguous because they are derived from the kept classes only.
            if len(files_per_class) > min_samples:
                class_index = len(self.classes)
                self.data.extend(files_per_class)
                self.labels.extend([class_index] * len(files_per_class))
                self.classes.append(class_name)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Load the image at ``idx`` and return ``(image, label)``.

        The image is converted to 3-channel RGB so that RGBA, palette or
        grayscale PNGs all yield the channel count the model in this notebook
        expects, and the file handle is closed right after decoding.
        """
        with Image.open(self.data[idx]) as raw:
            img = raw.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

In [None]:
# --- Configuration -----------------------------------------------------------
batch_size = 256
data_dir = "./train_final"

TRAIN_FRACTION = 0.8
VAL_FRACTION = 0.2
# Total number of samples drawn per epoch by the weighted samplers, divided
# between train and validation by the split fractions.
# NOTE(review): presumably chosen to match the dataset size -- confirm.
SAMPLES_PER_EPOCH = 8600
# Fixed seed so the train/validation split is the same on every run.
SPLIT_SEED = 0

# --- Transforms --------------------------------------------------------------
basic_transforms = transforms.Compose([
    transforms.ToTensor(),
    # antialias is set explicitly because its default for tensor inputs differs
    # between torchvision releases.
    transforms.Resize(256, antialias=True),
])

# The ImageNet-pretrained backbone expects inputs normalised with the ImageNet
# channel statistics. Applied last so rotation fill and crops act on [0, 1] data.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_transforms = transforms.Compose([
    basic_transforms,
    transforms.RandomRotation(degrees=(0, 180)),
    transforms.RandomCrop(224),
    normalize,
])

val_transforms = transforms.Compose([
    basic_transforms,
    transforms.CenterCrop(224),
    normalize,
])

# --- Datasets ----------------------------------------------------------------
dataset = SegmentedFood(data_dir, 100)

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [TRAIN_FRACTION, VAL_FRACTION],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# The sampler only reads labels. Handing it (None, label) pairs built from the
# subset indices avoids opening and decoding every image just to count classes.
train_label_view = [(None, dataset.labels[i]) for i in train_dataset.indices]
val_label_view = [(None, dataset.labels[i]) for i in val_dataset.indices]

train_sampler = ImbalancedClassSampler(train_label_view, int(SAMPLES_PER_EPOCH * TRAIN_FRACTION))
val_sampler = ImbalancedClassSampler(val_label_view, int(SAMPLES_PER_EPOCH * VAL_FRACTION))

train_dataset = Transformed(train_dataset, train_transforms)
val_dataset = Transformed(val_dataset, val_transforms)

# --- Loaders -----------------------------------------------------------------
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)

In [None]:
# Load an ImageNet-pretrained ResNet-50. `weights` must be a member of the
# ResNet50_Weights enum (or None); passing the enum class itself is rejected
# by torchvision's weight verification.
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

num_classes = len(dataset.classes)

# Freeze the pretrained backbone so only the new classification head is trained.
for param in resnet50.parameters():
    param.requires_grad = False

# Replace the 1000-way ImageNet head with one sized for this dataset. The new
# layer is created after the freeze, so its parameters remain trainable.
resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)

In [None]:
# Directory that receives one checkpoint per epoch.
model_save_dir_path = "nets/"
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(model_save_dir_path, exist_ok=True)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
resnet50.to(device)

loss = nn.CrossEntropyLoss()

# Only hand parameters that still require gradients to the optimizer; frozen
# ones would never be updated and only add bookkeeping.
trainable_params = [p for p in resnet50.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.0002)

In [None]:
# Train for a fixed number of epochs; after each epoch save a checkpoint and
# evaluate on the validation loader. Loss and accuracy are averaged over the
# number of samples seen, so a smaller final batch does not skew the metrics.
for epoch in range(60):
    # ---- train ----
    # NOTE(review): train() also puts the frozen backbone's BatchNorm layers in
    # training mode, so their running statistics keep updating -- confirm that
    # this is intended.
    resnet50.train()
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0
    with torch.set_grad_enabled(True):
        for data, target in train_data_loader:
            # Inputs must live on the same device as the model.
            data = data.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = resnet50(data)
            error = loss(output, target)
            error.backward()
            optimizer.step()

            prediction = output.argmax(dim=1).view(target.shape)
            batch_len = target.size(0)
            train_loss_sum += error.item() * batch_len
            train_correct += (prediction == target).sum().item()
            train_seen += batch_len

    # max(..., 1) keeps the report well-defined if the loader yields nothing.
    train_error = train_loss_sum / max(train_seen, 1)
    train_accuracy = train_correct / max(train_seen, 1)
    print("train loss", train_error, "train acc", train_accuracy)

    # ---- save model ----
    torch.save(resnet50.state_dict(), os.path.join(model_save_dir_path, f"resnet50_{epoch}.pth"))

    # ---- validate ----
    resnet50.eval()
    eval_loss_sum = 0.0
    eval_correct = 0
    eval_seen = 0
    preds = []
    labels = []
    with torch.no_grad():
        for data, target in val_data_loader:
            data = data.to(device)
            target = target.to(device)

            output = resnet50(data)
            error = loss(output, target)

            prediction = output.argmax(dim=1).view(target.shape)
            batch_len = target.size(0)
            eval_loss_sum += error.item() * batch_len
            eval_correct += (prediction == target).sum().item()
            eval_seen += batch_len

            # Keep predictions/labels on the CPU so they do not pin GPU memory.
            preds.append(prediction.cpu())
            labels.append(target.cpu())

    # torch.cat rejects an empty list, so fall back to empty tensors.
    preds = torch.cat(preds) if preds else torch.empty(0, dtype=torch.long)
    labels = torch.cat(labels) if labels else torch.empty(0, dtype=torch.long)

    eval_error = eval_loss_sum / max(eval_seen, 1)
    eval_accuracy = eval_correct / max(eval_seen, 1)
    print("eval loss", eval_error, "eval acc", eval_accuracy)

