In [None]:
import os
import random
import time
from typing import Iterator

import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from tqdm import tqdm

# Download and Extract dataset 
The download is about 11 GB.

In [None]:
from huggingface_hub import hf_hub_download
# Fetch the dataset archive from the Hugging Face Hub into the current directory.
hf_hub_download(
    repo_id="RayanAi/inat_train_modified",
    filename="inat_train_modified.tar.gz",
    repo_type="dataset",
    local_dir=".",
)

In [None]:
!tar xfz inat_train_modified.tar.gz -C .

In [None]:
class Node:
    """A node in a tree that groups entities by a hierarchical path.

    Entities are inserted with :meth:`add_to_node` using a sequence of path
    parts; one child node is created per part.  A node that directly holds
    entities is considered a leaf.

    Attributes:
        name: Identifier of the node.  Children created by
            :meth:`add_to_node` are named with the path prefix that leads to
            them (``path[:level + 1]``).
        children: Mapping from path part to child ``Node``.
    """

    def __init__(self, name, label=None):
        """Create an empty node.

        Args:
            name: Identifier of the node.
            label: Optional text shown by :meth:`print_node`.  When omitted,
                a string ``name`` is shown as-is and any other ``name`` is
                shown by its last element.
        """
        self.name = name
        self._label = label
        # Number of insertions that passed through this node to a descendant.
        self._count = 0
        self.children = {}
        # Entities stored directly on this node (non-empty only for leaves).
        self._entities = []

    def add_to_node(self, path, entity, level=0):
        """Insert ``entity`` under ``path``, creating missing child nodes.

        Args:
            path: Sequence of path parts, from the top of the tree downwards.
            entity: Value to store on the node at the end of ``path``.
            level: Index into ``path`` handled by this node (used by the
                recursion; callers normally leave it at 0).
        """
        if level >= len(path):
            # The whole path has been consumed: the entity belongs here.
            self._entities.append(entity)
            return
        part = path[level]
        if part not in self.children:
            self.children[part] = Node(path[:level + 1], label=part)
        self.children[part].add_to_node(path, entity, level=level + 1)
        self._count += 1

    @property
    def is_leaf(self):
        """True when at least one entity is stored directly on this node."""
        return len(self._entities) > 0

    @property
    def count(self):
        """Number of entities on this node (leaf) or inserted below it."""
        if self.is_leaf:
            return len(self._entities)
        return self._count

    @property
    def entities(self):
        """List of ``(entity, leaf_name)`` pairs found at or below this node."""
        if self.is_leaf:
            return [(entity, self.name) for entity in self._entities]
        child_entities = []
        for child in self.children.values():
            child_entities.extend(child.entities)
        return child_entities

    def level_iterator(self, level=None):
        """Yield the nodes located ``level`` steps below this node.

        Args:
            level: Depth to descend.  ``0`` yields this node itself; ``None``
                descends until leaves are reached and yields those.

        Raises:
            ValueError: If a leaf is reached before the requested depth.
        """
        if level == 0:
            yield self
        elif level is None and self.is_leaf:
            yield self
        elif self.is_leaf:
            # A leaf was hit while a positive depth was still requested.
            raise ValueError("Incorrect level is specified in tree.")
        else:
            next_level = None if level is None else level - 1
            for child in self.children.values():
                yield from child.level_iterator(next_level)

    def _display_label(self):
        """Return the text used for this node by :meth:`print_node`."""
        if self._label is not None:
            return self._label
        if isinstance(self.name, str):
            # Indexing a string name with [-1] would show only its last
            # character (e.g. "t" for "Dataset"), so show the full name.
            return self.name
        return self.name[-1]

    def print_node(self, level=0, max_level=None):
        """Print this node and its descendants as an indented tree.

        Args:
            level: Indentation depth of this node (four spaces per level).
            max_level: Deepest level to print; ``None`` prints everything.

        Returns:
            The number of nodes that were printed.
        """
        printed = 1
        print(' ' * (level * 4) + f"{self._display_label()} ({self.count})")
        if max_level is None or level < max_level:
            for node in self.children.values():
                printed += node.print_node(level + 1, max_level=max_level)
        return printed

In [None]:
class HiererchicalDataset(Dataset):
    """Image dataset whose class labels come from hierarchical folder names.

    Every sub-directory of ``dataset_path`` is one group.  The directory name
    is split on ``"_"`` and the first component is dropped (presumably a
    numeric id -- TODO confirm); the remaining parts form the group tuple.
    The training label of an image is the index of ``group[:level]`` in the
    sorted set of all such prefixes.

    Attributes:
        tree: ``Node`` tree holding, per group, the indices into ``data``.
        level: Number of leading group parts that define a class.
        classes: Mapping from group prefix tuple to integer class index.
        data: List of ``{"image_path": ..., "group": ...}`` records.
        transform: Transform applied to every image in ``__getitem__``.
    """

    def __init__(self, dataset_path, level=None):
        """Scan ``dataset_path`` and index every image it contains.

        Args:
            dataset_path: Directory containing one sub-directory per group.
            level: Number of leading group parts used as the class label;
                ``None`` selects 7.
        """
        self.tree = Node("Dataset") # keeps the group information of self.data in a tree (per index).
        self.level = level
        if level is None:
            self.level = 7  # Hardcoded
        self.classes = set()
        data = []
        index = 0
        for group_name in sorted(os.listdir(dataset_path)):
            group_dir = os.path.join(dataset_path, group_name)
            if not os.path.isdir(group_dir):
                continue
            # The group only depends on the directory, so parse it once.
            group = tuple(group_name.split("_")[1:])
            for image_name in sorted(os.listdir(group_dir)):
                data.append({
                        "image_path": os.path.join(group_dir, image_name),
                        "group": group,
                    }
                )
                self.tree.add_to_node(group, index)
                index += 1
                self.classes.add(group[:self.level])
        self.data = data
        self.classes = {group: index for (index, group) in enumerate(sorted(self.classes))}
        self.transform = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((0.4556, 0.4714, 0.3700), (0.2370, 0.2318, 0.2431))
        ])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(transformed_image, class_index)``."""
        record = self.data[idx]
        # Force three channels: the Normalize step above is configured for
        # exactly three, so grayscale/RGBA/CMYK files would otherwise fail.
        # The context manager closes the underlying file once it is decoded.
        with Image.open(record["image_path"]) as source_image:
            image = source_image.convert("RGB")
        target = self.classes[record["group"][:self.level]]
        if self.transform:
            image = self.transform(image)

        return image, target

    def get_group_iterator(self, level=None) -> Iterator[Node]:
        """Yield the tree nodes at depth ``level`` (leaves when ``None``)."""
        yield from self.tree.level_iterator(level)

In [None]:
# Quick check: index the raw training folder with classes defined by the
# first two group parts, then show its size, tree and label mapping.
test = HiererchicalDataset(dataset_path="train", level=2)
print("Dataset Length:", str(len(test)))
test.tree.print_node(level=0, max_level=2)
print(test.classes)

# Augmentation and Undersampling

In [None]:
from collections import defaultdict
from pathlib import Path

# Random augmentations used to synthesize extra samples for small classes:
# horizontal flip, rotation, brightness/contrast jitter and Gaussian blur.
AUGMENTATION_TRANSFORMATION = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
        transforms.GaussianBlur(kernel_size=(5, 1), sigma=(0.001, 5)),
    ]
)

class AugmentativeDataset(Dataset):
    """Class-balanced view of a dataset, built by augmenting and undersampling.

    Classes are defined by ``item["group"][:original_dataset.level]``.  Each
    class is brought to ``target_size`` samples:

    * smaller classes keep all their originals and are topped up with
      augmented copies, which are written as PNG files under
      ``output_dir/<group joined by "_">``;
    * larger classes are randomly sampled down;
    * classes already at ``target_size`` are kept unchanged.

    A directory is created for every class, including those that receive no
    augmented files.

    Attributes:
        original_dataset: The dataset that was balanced.
        output_dir: ``Path`` of the directory holding augmented images.
        balanced_data: List of ``{"image_path": ..., "group": ...}`` records.
        transform: Transform reused from ``original_dataset``.
    """

    def __init__(self, original_dataset, output_dir, target_size=10000):
        """Balance ``original_dataset`` and write augmented images to disk.

        Args:
            original_dataset: Dataset exposing ``data``, ``level``,
                ``classes`` and ``transform``.
            output_dir: Directory that receives the augmented images.
            target_size: Desired number of samples per class.
        """
        self.original_dataset = original_dataset
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Group the records by class in a single pass instead of re-scanning
        # the whole dataset once per class.
        level = original_dataset.level
        items_by_class = defaultdict(list)
        for item in original_dataset.data:
            items_by_class[item["group"][:level]].append(item)

        print("Original class distribution:")
        for group, group_data in items_by_class.items():
            print(f"Class {group}: {len(group_data)} samples")

        self.balanced_data = []
        for group, group_data in items_by_class.items():
            count = len(group_data)
            group_dir = self.output_dir / "_".join(group)
            group_dir.mkdir(parents=True, exist_ok=True)

            if count < target_size:
                required_samples = target_size - count
                print(f"Generating {required_samples} samples for class {group}")

                for _ in tqdm(range(required_samples)):
                    sample = random.choice(group_data)
                    # Convert to RGB so every saved image has three channels
                    # and modes PNG cannot store (e.g. CMYK) do not fail;
                    # the context manager closes the source file.
                    with Image.open(sample["image_path"]) as source_image:
                        augmented_image = AUGMENTATION_TRANSFORMATION(source_image.convert("RGB"))

                    # The running length keeps file names unique.
                    filename = f"{group}_{len(self.balanced_data)}.png"
                    filepath = group_dir / filename
                    augmented_image.save(filepath)

                    self.balanced_data.append({
                        "image_path": filepath,
                        "group": sample["group"]
                    })

                # Keep the real images too: the augmented copies only fill
                # the gap up to target_size.
                self.balanced_data.extend(group_data)

            elif count > target_size:
                print(f"Sampling down {count - target_size} samples for class {group}")
                sampled_data = random.sample(group_data, target_size)
                self.balanced_data.extend(sampled_data)
            else:
                self.balanced_data.extend(group_data)

        print("Data balancing complete.")
        print(f"Total samples after balancing: {len(self.balanced_data)}")

        self.transform = original_dataset.transform

    def __len__(self):
        """Return the number of samples after balancing."""
        return len(self.balanced_data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, class_index)``."""
        record = self.balanced_data[idx]
        # Force three channels and close the file once it has been decoded.
        with Image.open(record["image_path"]) as source_image:
            image = source_image.convert("RGB")
        target = self.original_dataset.classes[record["group"][:self.original_dataset.level]]
        if self.transform:
            image = self.transform(image)
        return image, target



In [None]:
! rm -rf balanced_train
dataset_path = 'train'
balanced_output_dir = 'balanced_train'
train_dataset = HiererchicalDataset(dataset_path=dataset_path, level=2)
balanced_dataset =  AugmentativeDataset(train_dataset, balanced_output_dir)

print("Balanced Dataset Length:", len(balanced_dataset))

In [None]:
import glob

class BalancedDataset(Dataset):
    """Dataset assembled from the balanced output folder plus the raw images.

    Every sub-directory of ``dataset_path`` names one class (group parts
    joined by ``"_"``):

    * if the directory contains files, exactly those files are used and the
      group is the directory name split on ``"_"``;
    * if it is empty, the ``.jpg`` images of all ``source_path`` directories
      matching ``*_<class>_*`` are used instead, randomly capped at
      ``max_per_class``; their group is the source directory name split on
      ``"_"`` without its first component.

    NOTE(review): for classes that do have generated files, the original
    images under ``source_path`` are not added here, so such a class consists
    of the generated files only -- confirm this is intended.

    Attributes:
        tree: ``Node`` tree holding, per group, the indices into ``data``.
        level: Number of leading group parts that define a class.
        classes: Mapping from group prefix tuple to integer class index.
        data: List of ``{"image_path": ..., "group": ...}`` records.
        transform: Transform applied to every image in ``__getitem__``.
    """

    def __init__(self, level=None, dataset_path='balanced_train', source_path='train', max_per_class=10000):
        """Index the balanced dataset.

        Args:
            level: Number of leading group parts used as the class label;
                ``None`` selects 7.
            dataset_path: Directory with one sub-directory per class.
            source_path: Directory with the raw images, used for classes
                whose sub-directory in ``dataset_path`` is empty.
            max_per_class: Upper bound on raw images taken for one class.
        """
        self.tree = Node("Dataset")
        self.level = level
        if level is None:
            self.level = 7
        self.classes = set()
        data = []
        index = 0
        for group_name in sorted(os.listdir(dataset_path)):
            group_dir = os.path.join(dataset_path, group_name)
            if not os.path.isdir(group_dir):
                continue

            files = os.listdir(group_dir)
            if files:
                # Files exist for this class: use them as they are.
                group = tuple(group_name.split("_"))
                records = [
                    (os.path.join(group_dir, image_name), group)
                    for image_name in sorted(files)
                ]
            else:
                # Nothing was generated for this class: read the raw images.
                files = glob.glob(os.path.join(source_path, f'*_{group_name}_*/*.jpg'))
                if len(files) > max_per_class:
                    files = random.sample(files, max_per_class)
                records = []
                for image_path in sorted(files):
                    # The group comes from the image's own directory name.
                    source_dir_name = os.path.basename(os.path.dirname(image_path))
                    records.append((image_path, tuple(source_dir_name.split("_")[1:])))

            for image_path, group in records:
                data.append({
                        "image_path": image_path,
                        "group": group,
                    }
                )
                self.tree.add_to_node(group, index)
                index += 1
                self.classes.add(group[:self.level])
        self.data = data
        self.classes = {group: index for (index, group) in enumerate(sorted(self.classes))}
        self.transform = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((0.4556, 0.4714, 0.3700), (0.2370, 0.2318, 0.2431))
        ])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(transformed_image, class_index)``."""
        record = self.data[idx]
        # Force three channels (Normalize above is configured for three) and
        # close the file once it has been decoded.
        with Image.open(record["image_path"]) as source_image:
            image = source_image.convert("RGB")
        target = self.classes[record["group"][:self.level]]
        if self.transform:
            image = self.transform(image)

        return image, target

    def get_group_iterator(self, level=None) -> Iterator[Node]:
        """Yield the tree nodes at depth ``level`` (leaves when ``None``)."""
        yield from self.tree.level_iterator(level)

In [None]:
# Re-index the balanced data from disk and show its size, tree and labels.
train_dataset = BalancedDataset(2)
print("Dataset Length:", str(len(train_dataset)))
train_dataset.tree.print_node(level=0, max_level=2)
print(train_dataset.classes)

# Model

In [None]:
from huggingface_hub import hf_hub_download
# Fetch the ResNet-50 checkpoint from the Hugging Face Hub into the current directory.
hf_hub_download(
    repo_id="RayanAi/resnet50-pretrained-inat",
    filename="resnet50.pth",
    repo_type="model",
    local_dir=".",
)

In [None]:
# Build the architecture only; the weights come from the local checkpoint.
model = models.resnet50(pretrained=False)

# Replace the classification head before loading, so that its output size is
# the number of classes in the training set.
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))

# map_location="cpu" lets the checkpoint load on hosts without a GPU even if
# its tensors were saved from CUDA; the model is moved to `device` later.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. The file is downloaded, so consider weights_only=True if the
# installed torch version supports it.
model.load_state_dict(torch.load("resnet50.pth", map_location="cpu"))


################### OPTIONAL #########################
# Mark every parameter (and, redundantly, the head) as trainable.
model.requires_grad_(True)
model.fc.requires_grad_(True)

# Train

In [None]:
! rm -rf checkpoints
!mkdir -p checkpoints

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
experiment_name = "simple-training"
print(f"Experiment {experiment_name}")

# Hyperparameters
learning_rate = 0.01
num_epochs = 20
batch_size = 256
checkpoint_dir = "./checkpoints"

print(f"Dataset Length: {len(train_dataset)}, Batch size: {batch_size}, LR: {learning_rate}")
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Loss function and optimizer.  Only the parameters of the classification
# head are handed to the optimizer; the commented line is the alternative
# that updates the whole network.
criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=learning_rate)
optimizer = optim.SGD(params=model.fc.parameters(), lr=learning_rate)

model = model.to(device)

def save_checkpoint(state, filename):
    """Write ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: Object to serialize (the training loop passes a dict).
        filename: Destination path of the checkpoint file.
    """
    torch.save(obj=state, f=filename)


def train(model, train_loader, criterion, optimizer, device, epoch, num_epochs):
    """Run one training epoch over ``train_loader``.

    Args:
        model: Network to optimize; it is switched to training mode.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch is moved to.
        epoch: Current epoch number (used for display only).
        num_epochs: Total number of epochs (used for display only).

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss averaged over all samples seen
        and the fraction of correctly predicted samples.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    started_at = time.time()

    batches = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}", leave=False)

    for inputs, labels in batches:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass and loss.
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics: weight the batch loss by batch size so that the epoch
        # value is a per-sample average.
        batch_loss = loss.item()
        loss_sum += batch_loss * inputs.size(0)
        predicted = outputs.max(1)[1]
        num_seen += labels.size(0)
        num_correct += predicted.eq(labels).sum().item()

        # Show the batch loss and the running accuracy in the progress bar.
        batches.set_postfix(loss=f"{batch_loss:.4f}", accuracy=f"{100. * num_correct / num_seen:.2f}%")

    epoch_loss = loss_sum / num_seen
    epoch_acc = num_correct / num_seen
    epoch_time = time.time() - started_at

    print(f"Epoch {epoch} | Training Loss: {epoch_loss:.4f} | Training Accuracy: {epoch_acc:.4f} | Time: {epoch_time:.2f}s")
    return epoch_loss, epoch_acc

best_acc = 0.0
print(f"Starting training for {num_epochs} epochs.")

# Main training loop: one pass over the data per epoch, then overwrite the
# "latest" checkpoint with the current model and optimizer state.
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train(
        model, train_loader, criterion, optimizer, device, epoch, num_epochs
    )

    # Save the latest model
    latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pth')
    latest_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'train_acc': train_acc,
    }
    save_checkpoint(latest_state, latest_checkpoint_path)

    print(
        f"Epoch [{epoch}/{num_epochs}] Summary: "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}"
    )

    # Disabled: save the best model based on training accuracy
    # if train_acc > best_acc:
    #     best_acc = train_acc
    #     best_checkpoint_path = os.path.join(checkpoint_dir, 'best_checkpoint.pth')
    #     save_checkpoint({
    #         'epoch': epoch,
    #         'model_state_dict': model.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         'train_loss': train_loss,
    #         'train_acc': train_acc,
    #     }, best_checkpoint_path)
    #     print(f"New best model saved with accuracy: {best_acc:.4f}")



In [None]:
import zipfile
# Save the trained weights and wrap them in the submission archive.
weights_file = 'resnet.pth'
torch.save(model.state_dict(), weights_file)
with zipfile.ZipFile('submission.zip', mode='w') as archive:
    archive.write(weights_file)