# Multi-Label Image Classification on the MLRSNet Dataset
Dataset link: [MLRSNet on Kaggle](https://www.kaggle.com/datasets/vigneshwar472/mlrs-net)

In [None]:
!pip install torchmetrics

In [None]:
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

from sklearn.preprocessing import MultiLabelBinarizer

import os
import ast
import pandas as pd
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Dataset Class

In [None]:
# Every split lives under one dataset folder (this looks like a Colab Google
# Drive mount). Each split has an ``images/`` directory and a
# ``<split>_labels.csv`` file.
_mlrs_root = "/content/drive/MyDrive/mlrs_dataset"

train_images_dir = f"{_mlrs_root}/train/images"
train_labels_csv = f"{_mlrs_root}/train/train_labels.csv"

val_images_dir = f"{_mlrs_root}/validation/images"
val_labels_csv = f"{_mlrs_root}/validation/validation_labels.csv"

test_images_dir = f"{_mlrs_root}/test/images"
test_labels_csv = f"{_mlrs_root}/test/test_labels.csv"

In [None]:
class MLRSDataset(Dataset):
    """Multi-label image dataset backed by an image folder and a labels CSV.

    The CSV must contain an ``image_id`` column (the image file stem) and a
    ``labels`` column holding a Python-literal list of label names per row,
    e.g. ``"['airplane', 'runway']"``.

    Args:
        images_dir: Directory containing the image files.
        csv_path: Path to the labels CSV.
        image_transforms: Optional callable applied to each RGB PIL image.
        classes: Ordered class names that define the target-vector layout.
            If ``None``, the sorted set of labels found in the CSV is used.
            Pass the training set's ``classes`` to the validation/test
            datasets so every split shares the same encoding.
        image_ext: Extension appended to ``image_id`` to form the file name.

    Each item is ``(image, target)``. ``target`` is a float32 multi-hot
    tensor of length ``len(self.classes)``.
    """

    def __init__(self, images_dir, csv_path, image_transforms=None, classes=None, image_ext=".jpg"):
        self.images_dir = images_dir
        self.image_transforms = image_transforms

        # Load the CSV and parse the string-encoded label lists.
        # literal_eval only evaluates Python literals, never arbitrary code.
        self.df = pd.read_csv(csv_path)
        self.df["labels"] = self.df["labels"].apply(ast.literal_eval)

        if classes is None:
            self.classes = sorted({label for sublist in self.df["labels"] for label in sublist})
        else:
            self.classes = classes

        self.mlb = MultiLabelBinarizer(classes=self.classes)
        self.mlb.fit(self.df["labels"])

        # Encode every row once, up front, rather than on each __getitem__
        # call. Row i is the multi-hot target for image i. sklearn's
        # MultiLabelBinarizer ignores (with a warning) labels outside
        # ``self.classes``.
        self.targets = torch.from_numpy(self.mlb.transform(self.df["labels"])).float()

        self.image_paths = [
            os.path.join(self.images_dir, f"{name}{image_ext}")
            for name in self.df["image_id"]
        ]

    @property
    def image_transorms(self):
        """Backward-compatible alias for the misspelled attribute name."""
        return self.image_transforms

    @image_transorms.setter
    def image_transorms(self, value):
        self.image_transforms = value

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Open inside a context manager so the file handle is released
        # immediately. convert() returns a new, fully loaded image, so it
        # stays valid after the file is closed.
        with Image.open(self.image_paths[index]) as img:
            image = img.convert("RGB")

        if self.image_transforms is not None:
            image = self.image_transforms(image)

        # Clone so callers mutating the target cannot corrupt the cache.
        return image, self.targets[index].clone()

# Helper Functions