In [None]:
# Import necessary libraries
import torch
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms, models
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from IPython.display import display, Image as IPImage
from google.colab import drive


In [None]:
# Paths for train, validation, and test datasets
TRAIN_PATH = r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Train"
VAL_PATH = r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Val"
TEST_PATH = r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Test"

# Class display names. CustomImageTextDataset assigns integer labels in *sorted*
# class-folder order, so this list is kept sorted to line up with those labels.
# TODO(review): confirm these names match the class folder names on disk.
class_names = sorted(["Red", "Blue", "Black", "TTR"])

# Reproducibility: seed the RNGs used for weight init, shuffling and augmentation.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Compute device used by the model, training and evaluation cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Image Transformation, Dataset Class, and Tokenizer

In [None]:
# Per-channel ImageNet statistics (the standard normalization for
# torchvision's ImageNet-pretrained models).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Image pipeline: fixed 224x224 resize, random horizontal flip (augmentation),
# PIL -> float tensor, then per-channel normalization.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Custom dataset class for loading images and text
class CustomImageTextDataset(Dataset):
    """Image + text dataset built from an ``<image_dir>/<class_name>/<file>`` layout.

    Every sub-directory of ``image_dir`` is one class; label indices are assigned
    in sorted folder-name order. The text input for each sample is the image file
    name without its extension, with underscores replaced by spaces.

    Args:
        image_dir: Root directory containing one sub-directory per class.
        tokenizer: Hugging Face-style tokenizer called on each text sample.
        max_len: Token sequence length; texts are padded/truncated to this length.
        image_transform: Optional callable applied to each RGB PIL image.
        extensions: Tuple of file suffixes (matched case-insensitively) treated as images.
    """

    def __init__(self, image_dir, tokenizer, max_len, image_transform=None,
                 extensions=('.png', '.jpg', '.jpeg')):
        self.image_paths = []
        self.texts = []
        self.labels = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform

        # Only directories are classes: stray files in the root (e.g. .DS_Store)
        # must not consume a label index or inflate the class count.
        class_folders = sorted(
            entry for entry in os.listdir(image_dir)
            if os.path.isdir(os.path.join(image_dir, entry))
        )
        self.label_map = {class_name: idx for idx, class_name in enumerate(class_folders)}

        suffixes = tuple(ext.lower() for ext in extensions)
        for class_name in class_folders:
            class_path = os.path.join(image_dir, class_name)
            # Sorted so the sample order is deterministic across file systems.
            for file_name in sorted(os.listdir(class_path)):
                if not file_name.lower().endswith(suffixes):
                    continue
                self.image_paths.append(os.path.join(class_path, file_name))
                self.texts.append(os.path.splitext(file_name)[0].replace('_', ' '))
                self.labels.append(self.label_map[class_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return a dict with the transformed image, flattened token ids/mask, and label."""
        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.image_transform:
            image = self.image_transform(image)

        encoding = self.tokenizer(
            self.texts[idx],
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'image': image,
            # return_tensors='pt' yields shape (1, max_len); flatten so the
            # DataLoader collates these to (batch, max_len).
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# Maximum token length for the filename-derived text; every sequence is padded/truncated to it.
max_len = 128
# Initialize tokenizer (uncased DistilBERT vocabulary, same checkpoint as GarbageModel's text encoder)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

Dataset and DataLoader Creation (train/validation/test splits come from separate folders)

In [None]:
# Load the three pre-split datasets (separate train / val / test folders on disk).
# Validation and test must be deterministic, so they reuse `image_transform`
# without its random augmentation (RandomHorizontalFlip).
eval_transform = transforms.Compose(
    [t for t in image_transform.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
)

train_dataset = CustomImageTextDataset(image_dir=TRAIN_PATH, tokenizer=tokenizer, max_len=max_len, image_transform=image_transform)
val_dataset = CustomImageTextDataset(image_dir=VAL_PATH, tokenizer=tokenizer, max_len=max_len, image_transform=eval_transform)
test_dataset = CustomImageTextDataset(image_dir=TEST_PATH, tokenizer=tokenizer, max_len=max_len, image_transform=eval_transform)

# Integer labels only mean the same thing across splits if the class folders match.
assert train_dataset.label_map == val_dataset.label_map == test_dataset.label_map, "class folders differ between splits"

# Data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Setting Up the Multimodal Model (DistilBERT text branch + ResNet-18 image branch)

In [None]:
class GarbageModel(nn.Module):
    """Late-fusion text + image classifier.

    Text branch: DistilBERT -> first-token hidden state -> dropout -> Linear(hidden, 128).
    Image branch: torchvision ResNet-18 -> flatten -> Linear(n_image_features, 128).
    The two 128-d vectors are concatenated and mapped to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        input_shape: (C, H, W) of one image; used only to size the image projection.
        transfer: If True, load the default pretrained ResNet-18 weights and freeze
            every ResNet parameter; if False, ResNet-18 is built with ``weights=None``
            and trained from scratch.
    """

    def __init__(self, num_classes, input_shape=(3, 224, 224), transfer=False):
        super().__init__()

        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.drop = nn.Dropout(0.3)
        self.text_out = nn.Linear(self.distilbert.config.hidden_size, 128)

        # NOTE(review): the full resnet18 (including its `fc` head) is used, so the
        # image "features" are that head's outputs rather than pooled backbone
        # features — confirm this is intended.
        self.feature_extractor = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if transfer else None)

        if transfer:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False

        n_image_features = self._get_conv_output(input_shape)
        self.image_classifier = nn.Linear(n_image_features, 128)
        # 256 = 128 text features + 128 image features.
        self.classifiermain = nn.Linear(256, num_classes)

    def _get_conv_output(self, shape):
        """Return the flattened feature size the feature extractor yields for one ``shape`` image.

        The probe runs in eval mode under ``no_grad``: in train mode, BatchNorm
        layers would update their running mean/variance from this dummy input.
        The extractor's original train/eval mode is restored afterwards.
        """
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, *shape)
                output_feat = self.feature_extractor(probe)
        finally:
            self.feature_extractor.train(was_training)
        return output_feat.reshape(1, -1).size(1)

    def forward(self, images, input_ids, attention_mask):
        """Return class logits of shape (batch, num_classes).

        Args:
            images: image batch — presumably (batch, 3, 224, 224) normalized tensors; TODO confirm.
            input_ids: token ids, presumably (batch, max_len) from the tokenizer.
            attention_mask: mask matching ``input_ids``.
        """
        # Hidden state of the first token ([CLS]) as the sentence representation.
        text_features = self.distilbert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        text_features = self.drop(text_features)
        text_features = self.text_out(text_features)

        image_features = self.feature_extractor(images)
        image_features = image_features.view(image_features.size(0), -1)
        image_features = self.image_classifier(image_features)

        combined_features = torch.cat((text_features, image_features), dim=1)
        output = self.classifiermain(combined_features)
        return output

# Size the classification head from the class folders found under TRAIN_PATH.
num_classes = len(train_dataset.label_map)
# NOTE(review): `transfer` is left at its default (False), so ResNet-18 is built with
# weights=None (random init). Pass transfer=True for pretrained weights if that is intended.
model = GarbageModel(num_classes=num_classes)
model = model.to(device)

Optimizer, Loss, and Learning-Rate Scheduler Setup

In [None]:
# Cross-entropy on the raw (un-softmaxed) logits produced by the model.
criterion = nn.CrossEntropyLoss()
num_epochs = 10

# Adam over all model parameters; ExponentialLR multiplies the LR by 0.9
# on every scheduler.step() call.
optimizer = Adam(model.parameters(), lr=1e-4)
scheduler = ExponentialLR(optimizer, gamma=0.9)

Training and Validation Loop

In [None]:
def _batch_to_device(batch):
    """Move one collated batch dict to `device`; returns (images, input_ids, attention_mask, labels)."""
    return (
        batch['image'].to(device),
        batch['input_ids'].to(device),
        batch['attention_mask'].to(device),
        batch['label'].to(device),
    )


def train(model, train_loader, val_loader):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Uses the notebook-level ``optimizer``, ``criterion``, ``scheduler``,
    ``num_epochs`` and ``device``. The LR scheduler is stepped once per epoch.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        train_loader: DataLoader yielding dicts with 'image', 'input_ids',
            'attention_mask' and 'label' entries.
        val_loader: DataLoader with the same batch format, used for validation.

    Returns:
        dict mapping "train_loss", "val_loss" and "val_accuracy" to per-epoch lists,
        e.g. for plotting learning curves.
    """
    history = {"train_loss": [], "val_loss": [], "val_accuracy": []}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            images, input_ids, attention_mask, labels = _batch_to_device(batch)

            optimizer.zero_grad()
            outputs = model(images, input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()

        # Validation
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []
        with torch.no_grad():
            for batch in val_loader:
                images, input_ids, attention_mask, labels = _batch_to_device(batch)

                outputs = model(images, input_ids, attention_mask)
                val_loss += criterion(outputs, labels).item()

                preds = outputs.argmax(dim=1)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        # max(..., 1) guards against an empty loader.
        mean_train_loss = train_loss / max(len(train_loader), 1)
        mean_val_loss = val_loss / max(len(val_loader), 1)
        accuracy = accuracy_score(all_labels, all_preds)

        history["train_loss"].append(mean_train_loss)
        history["val_loss"].append(mean_val_loss)
        history["val_accuracy"].append(accuracy)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {mean_train_loss}, Val Loss: {mean_val_loss}, Val Accuracy: {accuracy:.4f}")

    return history

history = train(model, train_loader, val_loader)

Test-Set Evaluation: Confusion Matrix and Metrics

In [None]:
def evaluate(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, print metrics, and plot the confusion matrix.

    Class names are taken from ``test_loader.dataset.label_map`` in label-index
    order, so the report and heatmap axes line up with the integer labels. The
    global ``class_names`` is used only when the loader has no such dataset.

    Sensitivity is macro-averaged recall. Specificity is the macro average of the
    per-class one-vs-rest TN / (TN + FP), computed from the confusion matrix.
    Macro precision is reported separately.

    Returns:
        dict with "accuracy", "sensitivity", "specificity", "precision" and
        "confusion_matrix".
    """
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(images, input_ids, attention_mask)
            preds = outputs.argmax(dim=1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    label_map = getattr(getattr(test_loader, "dataset", None), "label_map", None)
    names = sorted(label_map, key=label_map.get) if label_map else list(class_names)
    # Explicit label ids keep every class in the matrix/report even if it never appears.
    label_ids = list(range(len(names)))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    accuracy = accuracy_score(all_labels, all_preds)
    sensitivity = recall_score(all_labels, all_preds, labels=label_ids, average='macro', zero_division=0)
    precision = precision_score(all_labels, all_preds, labels=label_ids, average='macro', zero_division=0)

    # One-vs-rest counts per class: rows of cm are true labels, columns are predictions.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)
    negatives = tn + fp
    per_class_specificity = np.divide(
        tn, negatives, out=np.zeros(len(tp), dtype=float), where=negatives > 0
    )
    specificity = float(per_class_specificity.mean()) if len(tp) else 0.0

    print(f"Test Accuracy: {accuracy:.4f}, Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}")
    print(f"Macro Precision: {precision:.4f}")
    print("\nClassification Report:\n", classification_report(
        all_labels, all_preds, labels=label_ids, target_names=names, zero_division=0))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=names, yticklabels=names, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    plt.show()

    return {
        "accuracy": accuracy,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "precision": precision,
        "confusion_matrix": cm,
    }

test_metrics = evaluate(model, test_loader)

Mean Color Distribution and Class Distribution

In [None]:
# Function to calculate and display mean color distribution and class distribution
def plot_distributions(dataset, title_prefix, image_size=(224, 224)):
    """Plot the mean per-image pixel-intensity histogram and the class counts of ``dataset``.

    Pixels are read directly from ``dataset.image_paths``, converted to RGB and
    resized to ``image_size``. They are *not* read through ``dataset[i]``, which
    applies random augmentation and ImageNet normalization. Normalization pushes
    values outside [0, 1], so a (0, 1) histogram silently dropped them. Reading
    directly also avoids tokenizing every sample's text.

    Args:
        dataset: A CustomImageTextDataset (uses ``image_paths``, ``labels``, ``label_map``).
        title_prefix: Prefix for both subplot titles, e.g. "Train Set".
        image_size: (width, height) passed to PIL ``resize``; the default matches
            the 224x224 size used by the image transform.
    """
    index_to_name = {idx: name for name, idx in dataset.label_map.items()}
    class_counts = {name: 0 for name in dataset.label_map}
    # Running sum of per-image histograms (one bin per 8-bit intensity value).
    intensity_totals = np.zeros(256, dtype=np.float64)

    for image_path, label in zip(dataset.image_paths, dataset.labels):
        with Image.open(image_path) as img:
            pixels = np.asarray(img.convert("RGB").resize(image_size))
        # Counts of each 0-255 value, pooled over the R, G and B channels.
        intensity_totals += np.bincount(pixels.ravel(), minlength=256)
        class_counts[index_to_name[label]] += 1

    n_images = len(dataset.image_paths)
    mean_color_distribution = intensity_totals / n_images if n_images else intensity_totals

    # Plot distributions
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Mean color distribution plot
    axes[0].bar(np.arange(256), mean_color_distribution)
    axes[0].set_title(f"{title_prefix} Mean Color Distribution")
    axes[0].set_xlabel("Color Value")
    axes[0].set_ylabel("Number of Pixels")

    # Class distribution plot
    axes[1].bar(list(class_counts.keys()), list(class_counts.values()))
    axes[1].set_title(f"{title_prefix} Class Distribution")
    axes[1].set_xlabel("Classes")
    axes[1].set_ylabel("Number of Images")

    plt.tight_layout()
    plt.show()

# Plot distributions for train, validation, and test datasets
plot_distributions(train_dataset, "Train Set")
plot_distributions(val_dataset, "Validation Set")
plot_distributions(test_dataset, "Test Set")