In [None]:
import ast
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
from utils import SRC_DIR
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet50_Weights
from torch.utils.data import Dataset, DataLoader
from nn_utils import WeightedBinaryCrossEntropyLoss
from sklearn.metrics import (
    precision_score, recall_score,
    roc_auc_score, average_precision_score,
    hamming_loss, accuracy_score,
    coverage_error, f1_score
)

In [None]:
# Load the annotation table; the first CSV column is used as the index.
csv_path = './datasets/cropped_all_one_hot.csv'
data = pd.read_csv(csv_path, index_col=0)

# The 'labels' column holds the string repr of a Python list; parse it back.
data['labels'] = data['labels'].apply(ast.literal_eval)

# Keep only rows whose label list has more than 5 entries.
# `.copy()` detaches the filtered frame from the one read above, so the column
# assignments below write to a real frame instead of raising
# SettingWithCopyWarning on a potential view.
has_enough_labels = data['labels'].apply(lambda row_labels: len(row_labels) > 5)
data = data[has_enough_labels].copy()

# Derive the image file name (last path component) and the gallery name
# (third-from-last path component) from each file path.
data['file_names'] = data['file_path'].apply(lambda path: path.split("/")[-1])
data['galleries'] = data['file_path'].apply(lambda path: path.split("/")[-3])

# The raw path and the parsed label list are no longer needed downstream.
data = data.drop(columns=['file_path', 'labels'])
data.shape

In [None]:
# Collapse the per-image rows into one record per gallery:
#   file_paths[i] -> list of image file names belonging to gallery i
#   labels[i]     -> 0/1 vector, 1 where any image of the gallery has that class
grouped_data = data.groupby('galleries')
gallery_names = data['galleries'].unique()

file_paths = []
labels = []

progress = tqdm(gallery_names, total=len(gallery_names), desc='Processing galleries')
for gallery_name in progress:
    gallery_rows = grouped_data.get_group(gallery_name)
    file_paths.append(gallery_rows['file_names'].tolist())

    # Sum the class columns over the gallery's images; any positive total
    # means the class occurs at least once in this gallery.
    class_totals = gallery_rows.drop(columns=['file_names', 'galleries']).sum(axis=0).values
    gallery_vector = (class_totals > 0).astype(int)
    labels.append(gallery_vector)

In [None]:
# Turn the list of per-gallery label vectors into a single NumPy array.
labels = np.array(labels)

# Every column except the two bookkeeping columns is a class.
class_names = data.columns.drop(['file_names', 'galleries']).tolist()
num_classes = len(class_names)

# Persist the class order (one name per line) for future reference.
with open('class_names.txt', 'w') as names_file:
    names_file.writelines(f"{name}\n" for name in class_names)

print(f"Class names: {class_names}")

In [None]:
# Number of galleries in which each class occurs, and the number of galleries.
class_counts = np.sum(labels, axis=0)
total_samples = labels.shape[0]

# Weight per class: total_samples / (num_classes * class_count).
# A class with no positive gallery would divide by zero and produce an
# infinite weight that propagates into the loss, so the count used in the
# denominator is clamped to at least 1.
safe_counts = np.maximum(class_counts, 1)
class_weights = total_samples / (num_classes * safe_counts)

# Convert to a PyTorch tensor for the loss function.
pos_weight = torch.tensor(class_weights, dtype=torch.float32)
print(f"Class counts: {class_counts}")
print(f"Class weights: {class_weights}")

In [None]:
class AlbumDataset(Dataset):
    """Dataset yielding one fixed-length image sequence ("album") per item.

    Args:
        file_paths: Sequence of albums; each album is a list of image paths
            that are resolved relative to ``SRC_DIR``.
        labels: Sequence of per-album label vectors, indexable like
            ``file_paths``.
        transform: Callable applied to every loaded RGB image. It must return
            tensors of identical shape, because the images of an album are
            zero-padded with ``torch.zeros_like`` and combined with
            ``torch.stack``. With ``transform=None`` the raw PIL images cannot
            be stacked.
        max_seq_len: Number of images every returned album is padded or
            trimmed to.
    """

    def __init__(self, file_paths, labels, transform=None, max_seq_len=16):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.max_seq_len = max_seq_len

    def __len__(self):
        """Return the number of albums."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load, transform and pad/trim the album at position ``idx``.

        Returns:
            Tuple ``(processed_album, label, length)`` where
            ``processed_album`` stacks exactly ``max_seq_len`` image tensors,
            ``label`` is a float32 tensor and ``length`` is the number of real
            (non-padding) images, i.e. ``min(len(album), max_seq_len)``.

        Raises:
            IndexError: If the album contains no images, since there is no
                image to derive the padding shape from.
        """
        album_paths = self.file_paths[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        # Images beyond max_seq_len would be trimmed away anyway, so do not
        # spend time decoding and transforming them.
        kept_paths = album_paths[:self.max_seq_len]

        processed_album = []
        for img_path in kept_paths:
            # The context manager closes the underlying file handle right
            # away; convert() returns an image that no longer needs the file.
            with Image.open(SRC_DIR / img_path) as raw_img:
                img = raw_img.convert('RGB')
            if self.transform:
                img = self.transform(img)
            processed_album.append(img)

        # Zero-pad short albums up to max_seq_len.
        missing = self.max_seq_len - len(processed_album)
        if missing > 0:
            processed_album.extend([torch.zeros_like(processed_album[0])] * missing)

        processed_album = torch.stack(processed_album)
        length = len(kept_paths)

        return processed_album, label, length

In [None]:
def collate_fn(batch):
    """Combine ``(image_sequence, label, length)`` samples into batch tensors.

    Any image sequence shorter than the largest ``length`` in the batch is
    extended with all-zero frames shaped like its first frame; sequences that
    are already at least that long are passed through unchanged.

    Returns:
        Tuple ``(padded_images, labels, lengths)``: the stacked image
        sequences, the stacked labels, and the lengths as a ``torch.long``
        tensor.
    """
    image_seqs, label_seqs, seq_lengths = zip(*batch)
    target_len = max(seq_lengths)

    def _pad_to_target(seq):
        # Number of zero frames needed to reach the batch-wide target length.
        shortfall = target_len - len(seq)
        if shortfall <= 0:
            return seq
        zero_frames = torch.stack([torch.zeros_like(seq[0])] * shortfall)
        return torch.cat([seq, zero_frames])

    padded_images = torch.stack([_pad_to_target(seq) for seq in image_seqs])
    labels = torch.stack(label_seqs)
    lengths = torch.tensor(seq_lengths, dtype=torch.long)

    return padded_images, labels, lengths


# Preprocessing pipeline, applied to every image in this order:
# resize to 224x224, random flip / rotation / colour jitter, conversion to a
# tensor, then per-channel normalisation.
pil_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
]
tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(pil_steps + tensor_steps)

# One dataset item per gallery; batches of 4 albums, reshuffled every epoch.
dataset = AlbumDataset(file_paths, labels, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
class MultilabelAlbumModel(nn.Module):
    """CNN + LSTM model that predicts multi-label probabilities for an album.

    Every image of an album is encoded by a ResNet-50 whose classification
    head is replaced by ``nn.Identity`` (2048 features per image). The feature
    sequence is fed through an LSTM, and the LSTM output at each album's last
    real image is mapped to ``num_classes`` sigmoid probabilities.

    Args:
        num_classes: Number of output labels.
        max_seq_len: Length the LSTM output is padded back to; inputs must not
            contain more timesteps than this.
        lstm_hidden_size: Hidden size of the LSTM.
        lstm_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, max_seq_len, lstm_hidden_size=512, lstm_layers=2):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.cnn = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.cnn.fc = nn.Identity()
        self.rnn = nn.LSTM(input_size=2048, hidden_size=lstm_hidden_size, num_layers=lstm_layers, batch_first=True)
        # The LSTM is unidirectional, so each timestep emits lstm_hidden_size
        # features. Sizing this layer for 2 * lstm_hidden_size made forward()
        # fail with a shape mismatch.
        self.fc1 = nn.Linear(lstm_hidden_size, 256)
        # NOTE(review): fc1 and fc2 are applied back to back with no
        # activation in between - confirm that this is intended.
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x, lengths):
        """Compute per-class probabilities for a batch of albums.

        Args:
            x: Tensor of shape ``(batch, seq_len, channels, height, width)``.
            lengths: 1-D integer tensor with the number of real images in each
                album. It may live on any device.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with sigmoid outputs.
        """
        _, seq_len, _, _, _ = x.size()

        # Encode the images one timestep at a time: (batch, 2048) per step.
        cnn_features = torch.stack(
            [self.cnn(x[:, t, :, :, :]) for t in range(seq_len)],
            dim=1,
        )

        # pack_padded_sequence requires the lengths as a CPU int64 tensor,
        # regardless of where the features live.
        lengths_cpu = lengths.detach().to(device='cpu', dtype=torch.int64)
        packed_input = nn.utils.rnn.pack_padded_sequence(
            cnn_features, lengths_cpu, batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.rnn(packed_input)
        rnn_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, total_length=self.max_seq_len
        )

        # Select the LSTM output at the last real timestep of every album.
        # The gather index has to be on the same device as rnn_out.
        last_step = (lengths_cpu - 1).to(rnn_out.device)
        idx = last_step.view(-1, 1, 1).expand(-1, 1, rnn_out.size(2))
        last_output = rnn_out.gather(1, idx).squeeze(1)

        out = self.fc1(last_output)
        out = self.fc2(out)
        return torch.sigmoid(out)

In [None]:
def calculate_metrics(outputs, targets):
    """Compute multi-label classification metrics for a set of predictions.

    Args:
        outputs: Tensor of predicted scores, one column per class. Scores are
            thresholded at 0.5 to obtain hard predictions.
        targets: Tensor of ground-truth 0/1 labels with the same shape.

    Returns:
        Dict with the keys ``precision``, ``recall``, ``macro_f1``,
        ``micro_f1``, ``roc_auc``, ``pr_auc``, ``hamming_loss``,
        ``subset_accuracy`` and ``coverage_error``.
    """
    scores = outputs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    preds = (scores > 0.5).astype(int)

    # Threshold-based metrics use the hard 0/1 predictions. zero_division=0
    # keeps the value scikit-learn already returns for an undefined ratio but
    # without emitting a warning.
    precision = precision_score(targets, preds, average='micro', zero_division=0)
    recall = recall_score(targets, preds, average='micro', zero_division=0)
    macro_f1 = f1_score(targets, preds, average='macro', zero_division=0)
    micro_f1 = f1_score(targets, preds, average='micro', zero_division=0)
    hamming = hamming_loss(targets, preds)
    subset_acc = accuracy_score(targets, preds)

    # Ranking-based metrics need the continuous scores; feeding them the
    # thresholded predictions collapses each curve to a single operating
    # point and discards the ranking information.
    roc_auc = roc_auc_score(targets, scores, average='micro')
    pr_auc = average_precision_score(targets, scores, average='micro')
    coverage = coverage_error(targets, scores)

    return {
        "precision": precision,
        "recall": recall,
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "roc_auc": roc_auc,
        "pr_auc": pr_auc,
        "hamming_loss": hamming,
        "subset_accuracy": subset_acc,
        "coverage_error": coverage
    }

In [None]:
# Example training loop
num_epochs = 10
max_seq_len = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MultilabelAlbumModel(num_classes=num_classes, max_seq_len=max_seq_len)
model.to(device)

# Custom loss function. The class weights are moved to the training device so
# they can be combined with the model outputs there.
criterion = WeightedBinaryCrossEntropyLoss(pos_weight=pos_weight.to(device))

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    # tqdm wraps the dataloader directly, so the bar advances once per batch.
    # The batch labels get their own name so the notebook-level `labels` array
    # built in an earlier cell is not overwritten.
    for images, batch_labels, lengths in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
        images = images.to(device)
        batch_labels = batch_labels.to(device)
        lengths = lengths.to(device)

        optimizer.zero_grad()
        outputs = model(images, lengths)

        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Keep detached CPU copies only: storing the raw outputs would keep
        # every batch's autograd graph (and device memory) alive for the whole
        # epoch.
        all_outputs.append(outputs.detach().cpu())
        all_labels.append(batch_labels.detach().cpu())

    epoch_loss = running_loss / len(dataloader)
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)

    metrics = calculate_metrics(all_outputs, all_labels)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Micro F1: {metrics["micro_f1"]:.4f}')
    print(f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, Macro F1: {metrics['macro_f1']:.4f}")
    print(f"ROC AUC: {metrics['roc_auc']:.4f}, PR AUC: {metrics['pr_auc']:.4f}, Hamming Loss: {metrics['hamming_loss']:.4f}")
    print(f"Subset Accuracy: {metrics['subset_accuracy']:.4f}, Coverage Error: {metrics['coverage_error']:.4f}")