In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from PIL import Image
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from nn_utils import WeightedBinaryCrossEntropyLoss
from sklearn.metrics import (
    precision_score, recall_score,
    roc_auc_score, average_precision_score,
    hamming_loss, accuracy_score,
    coverage_error, f1_score
)

In [2]:
# Read the CSV file
csv_path = 'datasets/all_one_hot.csv'
data = pd.read_csv(csv_path, index_col=0)

# Extract image paths and labels. Every column except 'file_path' and
# 'labels' is treated as a one-hot class indicator column.
file_paths = data['file_path'].tolist()
_labels_df = data.drop(columns=['file_path', 'labels'])
labels = _labels_df.values
class_names = _labels_df.columns.tolist()  # Save the class names
num_classes = len(class_names)

# Save class names to a file for future reference (one name per line).
with open('class_names.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{class_name}\n" for class_name in class_names)

print(f"Class names: {class_names}")

# Calculate the number of occurrences for each class
class_counts = np.sum(labels, axis=0)
total_samples = labels.shape[0]

# A class with no positive samples would make the division below produce an
# infinite weight (and, in a typical BCE formulation, inf * 0 => NaN loss).
# Clamp counts to 1 so such classes get a large but finite weight, and report
# them so the dataset can be fixed.
empty_classes = [name for name, count in zip(class_names, class_counts) if count == 0]
if empty_classes:
    print(f"Warning: classes with no positive samples: {empty_classes}")
safe_class_counts = np.maximum(class_counts, 1)

# Calculate the weights: total / (num_classes * count), so rarer classes get
# proportionally larger weights.
class_weights = total_samples / (num_classes * safe_class_counts)

# Convert to a PyTorch tensor
pos_weight = torch.tensor(class_weights, dtype=torch.float32)
print(f"Class counts: {class_counts}")
print(f"Class weights: {class_weights}")

Class names: ['anal sex', 'anilingus', 'asian', 'ass', 'bath', 'bdsm', 'beach', 'big breasts', 'big woman', 'bikini', 'black penis', 'blonde', 'blowjob', 'bondage', 'boots', 'brunette', 'butt plug', 'chubby', 'clothed', 'cosplay', 'cowgirl (sex position)', 'creampie', 'cum in mouth', 'cumshot', 'cunnilingus', 'curly', 'curvy', 'dildo', 'doggy style (sex position)', 'double penetration', 'dress', 'ebony', 'facial', 'feet', 'fellatio', 'fisting', 'footjob', 'glasses', 'granny', 'group sex', 'gym', 'hairy vulva', 'handjob', 'heels', 'intercourse', 'interracial', 'jeans', 'kissing', 'latina', 'leather', 'lesbian', 'lingerie', 'maid', 'masturbation', 'mature', 'milf', 'missionary (sex position)', 'naked breasts', 'nurse', 'outdoor', 'panties', 'pantyhose', 'petite', 'pissing', 'pool', 'public', 'redhead', 'sandals', 'selfie', 'sex toys', 'shorts', 'shower', 'skinny', 'skirt', 'small breasts', 'smoking', 'socks', 'sports', 'stockings', 'tattoo', 'teacher', 'teen', 'thick', 'thong', 'threesom

In [3]:
class AlbumDataset(Dataset):
    """Dataset of image albums paired with label vectors.

    Each item is one album: a sequence of images loaded from disk, truncated
    or zero-padded to exactly ``max_seq_len`` frames.

    Args:
        file_paths: Sequence with one entry per album. An entry may be a list
            of image paths, or a single path string (treated as a one-image
            album instead of being iterated character by character).
        labels: Indexable of per-album label vectors (e.g. a 2-D numpy array).
        transform: Optional callable applied to each RGB PIL image. It must
            return a tensor so the frames can be zero-padded and stacked.
        max_seq_len: Number of frames every album is truncated/padded to.

    ``__getitem__`` returns ``(album, label, length)``: the stacked frames
    (first dimension ``max_seq_len``), the label as a float32 tensor, and the
    number of real (non-padding) frames.
    """

    def __init__(self, file_paths, labels, transform=None, max_seq_len=16):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.file_paths)

    def _load_image(self, img_path):
        """Open one image as RGB and apply ``self.transform`` if it is set."""
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def __getitem__(self, idx):
        album_paths = self.file_paths[idx]
        # A bare string (e.g. straight from a CSV column) would otherwise be
        # iterated character by character.
        if isinstance(album_paths, str):
            album_paths = [album_paths]
        # Truncate long albums: otherwise albums longer than max_seq_len yield
        # tensors of different lengths that cannot be batched together, and
        # the surplus images are loaded for nothing.
        album_paths = list(album_paths)[:self.max_seq_len]
        if not album_paths:
            raise ValueError(f"Album at index {idx} contains no images")

        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        frames = [self._load_image(path) for path in album_paths]
        length = len(frames)

        # Pad with all-zero frames up to max_seq_len.
        if length < self.max_seq_len:
            frames.extend(torch.zeros_like(frames[0]) for _ in range(self.max_seq_len - length))

        return torch.stack(frames), label, length

In [4]:
# Preprocessing and augmentation: every frame is resized to a fixed size,
# randomly flipped/rotated/colour-jittered, converted to a tensor and
# normalised with the ImageNet channel statistics.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(train_steps)


def collate_albums(batch):
    """Turn a list of (album, label, length) items into three batch tensors."""
    albums, batch_labels, lengths = zip(*batch)
    return torch.stack(albums), torch.stack(batch_labels), torch.tensor(lengths)


# Wrap the albums in a dataset and a shuffled, batched loader.
dataset = AlbumDataset(file_paths, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_albums)

In [5]:
class MultilabelAlbumModel(nn.Module):
    """Per-frame ResNet-50 encoder followed by a bidirectional LSTM classifier.

    Each frame is encoded independently by the CNN. The per-frame features go
    through a 2-layer bidirectional LSTM, and the final hidden states of both
    directions are mapped to ``num_classes`` independent sigmoid probabilities.

    Args:
        num_classes: Number of output labels.
        max_seq_len: Padded album length used by the data pipeline. Stored for
            reference; the forward pass derives what it needs from ``lengths``.
    """

    def __init__(self, num_classes, max_seq_len):
        super(MultilabelAlbumModel, self).__init__()
        self.max_seq_len = max_seq_len
        # NOTE(review): newer torchvision deprecates `pretrained=True` in favour
        # of `weights=...`; kept here for compatibility with older versions.
        self.cnn = models.resnet50(pretrained=True)
        self.cnn.fc = nn.Identity()  # expose the 2048-d pooled features
        self.rnn = nn.LSTM(input_size=2048, hidden_size=512, num_layers=2, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(512 * 2, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x, lengths):
        """Return per-class probabilities for a batch of padded albums.

        Args:
            x: Tensor of shape ``(batch, seq_len, C, H, W)``.
            lengths: Number of real frames per album (tensor or list, on any
                device).

        Returns:
            Tensor of shape ``(batch, num_classes)`` with values in (0, 1).
        """
        seq_len = x.size(1)
        # Encode each time step separately, one frame per album at a time.
        cnn_features = torch.stack([self.cnn(x[:, t]) for t in range(seq_len)], dim=1)

        # pack_padded_sequence requires `lengths` as a CPU int64 tensor; a
        # CUDA tensor (as produced by moving the batch to the GPU) raises.
        lengths_cpu = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        packed_input = nn.utils.rnn.pack_padded_sequence(
            cnn_features, lengths_cpu, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.rnn(packed_input)

        # h_n is (num_layers * 2, batch, hidden); its last two rows are the top
        # layer's forward and backward final states, taken at each album's
        # true length because the input is packed. The backward direction's
        # *output* at position length-1 has only seen the last frame, so the
        # final hidden states are the correct album summary.
        last_output = torch.cat([h_n[-2], h_n[-1]], dim=1)

        # Nonlinearity between the two linear layers; without it
        # fc2(fc1(x)) collapses to a single linear map.
        out = torch.relu(self.fc1(last_output))
        out = self.fc2(out)
        return torch.sigmoid(out)

In [None]:
def calculate_metrics(outputs, targets, threshold=0.5):
    """Compute multi-label classification metrics for a set of predictions.

    Args:
        outputs: Tensor of predicted per-class probabilities, shape
            ``(n_samples, n_classes)``.
        targets: Tensor of 0/1 ground-truth labels with the same shape.
        threshold: Probability above which a class counts as predicted for
            the threshold-based metrics (default 0.5).

    Returns:
        Dict mapping metric name to value. Precision, recall, F1, Hamming loss
        and subset accuracy use the thresholded predictions. ROC AUC, PR AUC
        and coverage error are ranking metrics and use the raw probabilities;
        computing them on 0/1 predictions discards the ranking information.
    """
    scores = outputs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    predictions = (scores > threshold).astype(int)

    return {
        "precision": precision_score(targets, predictions, average='micro'),
        "recall": recall_score(targets, predictions, average='micro'),
        "macro_f1": f1_score(targets, predictions, average='macro'),
        "micro_f1": f1_score(targets, predictions, average='micro'),
        "roc_auc": roc_auc_score(targets, scores, average='micro'),
        "pr_auc": average_precision_score(targets, scores, average='micro'),
        "hamming_loss": hamming_loss(targets, predictions),
        "subset_accuracy": accuracy_score(targets, predictions),
        "coverage_error": coverage_error(targets, scores),
    }


# Example training loop with additional metrics
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 10
# Take the padding length from the dataset so the model and the DataLoader
# agree (this cell previously hard-coded 24 while AlbumDataset padded to 16).
max_seq_len = dataset.max_seq_len
model = MultilabelAlbumModel(num_classes=num_classes, max_seq_len=max_seq_len)
model.to(device)
# Put the class weights on the training device; whether
# WeightedBinaryCrossEntropyLoss moves them itself is not visible here.
criterion = WeightedBinaryCrossEntropyLoss(pos_weight=pos_weight.to(device))
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    epoch_outputs = []
    epoch_targets = []

    # Named `batch_labels` (not `labels`) so the dataset-level `labels` array
    # from the data-loading cell is not overwritten by the last batch.
    for images, batch_labels, lengths in dataloader:
        images = images.to(device)
        batch_labels = batch_labels.to(device)
        # `lengths` stays on the CPU: pack_padded_sequence inside the model
        # requires CPU lengths.

        optimizer.zero_grad()
        outputs = model(images, lengths)

        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Detach and move to the CPU so the predictions kept for the epoch
        # metrics hold no autograd history and no device memory.
        epoch_outputs.append(outputs.detach().cpu())
        epoch_targets.append(batch_labels.cpu())

    epoch_loss = running_loss / len(dataloader)
    # NOTE: these metrics are computed on the augmented training batches,
    # not on a held-out validation set.
    metrics = calculate_metrics(torch.cat(epoch_outputs), torch.cat(epoch_targets))

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Micro F1: {metrics["micro_f1"]:.4f}')
    print(
        f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, Macro F1: {metrics['macro_f1']:.4f}")
    print(
        f"ROC AUC: {metrics['roc_auc']:.4f}, PR AUC: {metrics['pr_auc']:.4f}, Hamming Loss: {metrics['hamming_loss']:.4f}")
    print(f"Subset Accuracy: {metrics['subset_accuracy']:.4f}, Coverage Error: {metrics['coverage_error']:.4f}")