In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import BertModel, BertTokenizer
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_auc_score
from PIL import Image
import re

In [None]:
# ─── Pick The Compute Device (Reused By Later Cells) ────────────────
# Prefer CUDA when PyTorch reports it as available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# ─── Make Sure The Checkpoint Folder Exists ─────────────────────────
# exist_ok=True keeps this cell safe to re-run.
os.makedirs("checkpoints", exist_ok=True)

In [None]:
# ─── Dataset Class ──────────────────────────────────────────────────
class GarbageDataset(Dataset):
    """Image + file-name-text dataset read from one sub-folder per class.

    ``root_dir`` must contain one folder for each entry of ``classes``
    ("Black", "Blue", "Green", "TTR"). Every file in those folders whose name
    ends with ".jpg" or ".png" becomes one sample. Its text description is the
    file name up to the first ".", with every run of digits removed, and its
    label is the index of the folder name within ``classes``.

    Args:
        root_dir: Directory holding the per-class sub-folders.
        transform: Optional callable applied to the RGB PIL image.
        tokenizer: Optional callable applied to the text description. It is
            called with ``return_tensors="pt"``, ``padding="max_length"``,
            ``truncation=True`` and ``max_length=128``.

    Raises:
        OSError: Propagated from ``os.listdir`` when a class folder cannot be
            listed (for example, because it does not exist).
    """

    def __init__(self, root_dir, transform=None, tokenizer=None):
        self.root_dir = root_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.classes = ["Black", "Blue", "Green", "TTR"]
        # Each entry is (image path, text description, integer class index).
        self.data = []
        for class_idx, label in enumerate(self.classes):
            class_dir = os.path.join(root_dir, label)
            # os.listdir returns entries in arbitrary order; sorting makes the
            # sample index -> file mapping reproducible across runs/machines.
            for file_name in sorted(os.listdir(class_dir)):
                if not file_name.endswith((".jpg", ".png")):
                    continue
                text_description = re.sub(r"\d+", "", file_name.split(".")[0])
                self.data.append(
                    (
                        os.path.join(class_dir, file_name),
                        text_description,
                        class_idx,
                    )
                )

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, text, label)`` for the sample at position ``idx``.

        ``image`` is the RGB image (after ``transform`` when one was given),
        ``text`` is the description string (or the tokenizer output when a
        tokenizer was given) and ``label`` is the integer class index.
        """
        img_path, text, label = self.data[idx]
        # Use a context manager so the underlying file is closed right away;
        # the converted RGB image is what gets used from here on.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        if self.tokenizer:
            text = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=128,
            )
        return image, text, label

In [None]:
# ─── Image Transform ────────────────────────────────────────────────
# Resize to 224x224, convert to a float tensor, then normalize with the
# ImageNet per-channel mean/std. The ResNet-50 backbone is loaded with
# ImageNet-pretrained weights, which expect inputs normalized this way.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ─── Tokenizer ──────────────────────────────────────────────────────
# Same checkpoint name as the BERT text encoder used by the model.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
# ─── Build Each Split And Its Loader ────────────────────────────────
# All three splits share the same image transform and tokenizer; only the
# training loader shuffles its samples.

# ─── Training ───────────────────────────────────────────────────────
train_dataset = GarbageDataset(
    "garbage_data/CVPR_2024_dataset_Train", transform=transform, tokenizer=tokenizer
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# ─── Validation ─────────────────────────────────────────────────────
val_dataset = GarbageDataset(
    "garbage_data/CVPR_2024_dataset_Val", transform=transform, tokenizer=tokenizer
)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

# ─── Testing ────────────────────────────────────────────────────────
test_dataset = GarbageDataset(
    "garbage_data/CVPR_2024_dataset_Test", transform=transform, tokenizer=tokenizer
)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [None]:
# ─── Model Class ────────────────────────────────────────────────────
class MultimodalModel(nn.Module):
    """Late-fusion classifier over ResNet-50 image features and BERT text features.

    The ResNet-50 classification head is replaced by ``nn.Identity`` so the
    backbone outputs its pooled feature vector. The text branch uses the
    hidden state at sequence position 0 of ``bert-base-uncased``. Both vectors
    are concatenated and passed through three linear layers, with sigmoid
    activations and dropout after the first two.

    Args:
        num_classes: Size of the output layer (number of target classes).
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # ─── Image Model ────────────────────────────────────
        self.image_model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        num_ftrs = self.image_model.fc.in_features
        # Drop the ImageNet head so the backbone returns raw features.
        self.image_model.fc = nn.Identity()

        # ─── Text Model ─────────────────────────────────────
        self.text_model = BertModel.from_pretrained("bert-base-uncased")

        # ─── Combined Classifiers With Hidden Layers ────────
        self.fc1 = nn.Linear(num_ftrs + self.text_model.config.hidden_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.2)

    def forward(self, image, text):
        """Return raw class logits for a batch.

        Args:
            image: Image batch fed to the ResNet-50 backbone.
            text: Mapping of keyword arguments unpacked into the BERT model
                (e.g. the tensors produced by the tokenizer).

        Returns:
            Tensor with ``num_classes`` values per sample. These are
            unnormalized logits: no activation is applied after the last
            linear layer, because ``nn.CrossEntropyLoss`` applies log-softmax
            itself and callers apply softmax when they need probabilities.
        """
        # ─── Image Features ─────────────────────────────────
        image_features = self.image_model(image)

        # ─── Text Features ──────────────────────────────────
        # Keep only the hidden state at sequence position 0.
        text_features = self.text_model(**text).last_hidden_state[:, 0, :]

        # ─── Combine Features ───────────────────────────────
        combined_features = torch.cat((image_features, text_features), dim=1)

        # ─── Pass Through Additional Hidden Layers ──────────
        hidden = self.dropout(self.sigmoid(self.fc1(combined_features)))
        hidden = self.dropout(self.sigmoid(self.fc2(hidden)))

        # No final sigmoid: squashing logits into (0, 1) before the
        # cross-entropy loss flattens the softmax and hampers training.
        return self.fc3(hidden)

In [None]:
# ─── Build The Model On The Selected Device ─────────────────────────
# Module.to() moves the parameters in place and returns the same module.
model = MultimodalModel(num_classes=4).to(device)
model

In [None]:
# ─── Criterion ──────────────────────────────────────────────────────
criterion = nn.CrossEntropyLoss()
# ─── Optimizer ──────────────────────────────────────────────────────
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# ─── Scheduler ──────────────────────────────────────────────────────
# Scale the learning rate by 0.1 when the value passed to scheduler.step()
# stops decreasing (patience=3). The `verbose` flag is deliberately not
# passed: it is deprecated in recent PyTorch releases. The current learning
# rate can be read from optimizer.param_groups[0]["lr"] if it needs logging.
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=3, factor=0.1)

In [None]:
# ─── Training Function ──────────────────────────────────────────────
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs=20,
    patience=5,
):
    """Train ``model`` with per-epoch validation, checkpoints and early stopping.

    After every epoch the model is evaluated with ``evaluate_model`` on
    ``val_loader``, the scheduler is stepped with the validation loss, and the
    weights are written to ``checkpoints/model_epoch_<n>.pth``. Whenever the
    validation loss improves on the best seen so far, the weights are also
    written to ``best_model.pth``.

    Args:
        model: Module called as ``model(images, texts)``.
        train_loader: Iterable yielding ``(images, texts, labels)`` batches.
        val_loader: Loader handed to ``evaluate_model``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive non-improving epochs that triggers
            early stopping.

    Returns:
        None. Results are reported via ``print`` and saved checkpoint files.
    """
    lowest_val_loss = float("inf")
    stale_epochs = 0

    for epoch in range(num_epochs):
        # ─── One Pass Over The Training Data ────────────────
        model.train()
        weighted_loss_sum = 0.0
        for batch_images, batch_texts, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # squeeze(1) removes the size-1 axis at dim 1 that each sample's
            # tokenizer output contributes once samples are batched.
            batch_texts = {
                field: tensor.squeeze(1).to(device)
                for field, tensor in batch_texts.items()
            }

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_images, batch_texts), batch_labels)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            weighted_loss_sum += batch_loss.item() * batch_images.size(0)

        epoch_loss = weighted_loss_sum / len(train_loader.dataset)
        val_loss = evaluate_model(model, val_loader, criterion)

        print(
            f"<|||  Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}  |||>"
        )

        scheduler.step(val_loss)

        # ─── Save Model Checkpoint ──────────────────────────
        torch.save(model.state_dict(), f"checkpoints/model_epoch_{epoch+1}.pth")

        # ─── Track The Best Model / Early Stopping ──────────
        improved = val_loss < lowest_val_loss
        if improved:
            lowest_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")
        stale_epochs = 0 if improved else stale_epochs + 1

        if stale_epochs >= patience:
            print("Early stopping")
            break


# ─── Evaluation Function ────────────────────────────────────────────
def evaluate_model(model, val_loader, criterion, title="Validation"):
    """Evaluate ``model`` on a loader, print its metrics and return the loss.

    Runs the model in eval mode without gradient tracking, then prints the
    loss, accuracy, weighted F1, weighted one-vs-rest ROC AUC and the
    confusion matrix, prefixed by ``title``.

    Args:
        model: Module called as ``model(images, texts)``.
        val_loader: Iterable yielding ``(images, texts, labels)`` batches; its
            ``dataset`` length is used to average the loss.
        criterion: Loss called as ``criterion(outputs, labels)``.
        title: Label printed in front of the metrics line.

    Returns:
        The batch-size-weighted average of the per-batch losses over the whole
        dataset (a per-sample mean when ``criterion`` returns batch means).
    """
    model.eval()
    weighted_loss_sum = 0.0
    true_labels, pred_labels, class_probs = [], [], []

    with torch.no_grad():
        for batch_images, batch_texts, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            batch_texts = {
                field: tensor.squeeze(1).to(device)
                for field, tensor in batch_texts.items()
            }

            outputs = model(batch_images, batch_texts)
            weighted_loss_sum += criterion(outputs, batch_labels).item() * batch_images.size(0)

            # Softmax over the class axis gives the scores used for AUC; the
            # predicted class is the highest-scoring one.
            batch_probs = torch.softmax(outputs, dim=1)
            batch_preds = batch_probs.argmax(dim=1)

            true_labels.extend(batch_labels.cpu().numpy())
            pred_labels.extend(batch_preds.cpu().numpy())
            class_probs.extend(batch_probs.cpu().numpy())

    val_loss = weighted_loss_sum / len(val_loader.dataset)
    accuracy = accuracy_score(true_labels, pred_labels)
    conf_matrix = confusion_matrix(true_labels, pred_labels)
    f1 = f1_score(true_labels, pred_labels, average="weighted")
    auc = roc_auc_score(true_labels, class_probs, average="weighted", multi_class="ovr")

    print('<========================================================================================>')
    print(
        f"{title} | Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, AUC: {auc:.4f}"
    )
    print()
    print(f"Confusion Matrix:\n{conf_matrix}")

    return val_loss

In [None]:
# ─── Run Training ───────────────────────────────────────────────────
# Up to 10 epochs; stops early after 5 epochs in a row without a better
# validation loss.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=10,
    patience=5,
)

In [None]:
# ─── Load The Best Model ────────────────────────────────────────────
# map_location remaps the saved tensors onto the active device, so a
# checkpoint written on a GPU can still be restored on a CPU-only machine.
model.load_state_dict(torch.load("best_model.pth", map_location=device))

# ─── Evaluate On Validation Set ─────────────────────────────────────
evaluate_model(model, val_loader, criterion, "Validation")

# ─── Evaluate On Test Set ───────────────────────────────────────────
evaluate_model(model, test_loader, criterion, "Test")