In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import re
import torch.nn as nn
from torchvision import models
from transformers import BertModel, BertTokenizer
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_auc_score

In [None]:
# Select the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Create the checkpoint directory; an already-existing one is not an error.
os.makedirs("checkpoints", exist_ok=True)

In [None]:
# ─── Dataset Class ──────────────────────────────────────────────────
class GarbageDataset(Dataset):
    """Image/text dataset read from a directory with one sub-folder per class.

    ``root_dir`` must contain the sub-directories ``Black``, ``Blue``,
    ``Green`` and ``TTR``. Every ``.jpg`` / ``.png`` file found directly
    inside them becomes one sample. The text for a sample is derived from
    its file name: the extension is dropped and all digit runs are removed.

    Args:
        root_dir: Directory holding the per-class sub-folders.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        tokenizer: Optional callable applied to the text description; it is
            invoked with ``return_tensors="pt"``, ``padding="max_length"``,
            ``truncation=True`` and ``max_length=128``.

    Raises:
        FileNotFoundError: If one of the class sub-directories is missing
            (propagated from ``os.listdir``).
    """

    # File-name suffixes (matched case-sensitively) treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, root_dir, transform=None, tokenizer=None):
        self.root_dir = root_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.classes = ["Black", "Blue", "Green", "TTR"]
        # Each entry is (image_path, text_description, class_index).
        self.data = []
        for class_index, label in enumerate(self.classes):
            class_dir = os.path.join(root_dir, label)
            # os.listdir yields names in arbitrary order; sorting keeps the
            # index -> sample mapping reproducible across machines and runs.
            for file_name in sorted(os.listdir(class_dir)):
                if not file_name.endswith(self.IMAGE_EXTENSIONS):
                    continue
                # splitext strips only the final extension, so names that
                # contain extra dots are not truncated at the first one.
                stem = os.path.splitext(file_name)[0]
                text_description = re.sub(r"\d+", "", stem)
                self.data.append(
                    (
                        os.path.join(class_dir, file_name),
                        text_description,
                        class_index,
                    )
                )

    def __len__(self) -> int:
        """Return the number of samples collected at construction time."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load and return the sample at position ``idx``.

        Returns:
            A ``(image, text, label)`` tuple. ``image`` is the RGB image,
            passed through ``transform`` when one was given. ``text`` is the
            tokenizer output when a tokenizer was given, otherwise the raw
            description string. ``label`` is the integer index of the
            sample's class in ``self.classes``.
        """
        img_path, text, label = self.data[idx]
        # The context manager releases the file handle deterministically;
        # convert() produces an image that no longer depends on the file.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        # Compare against None rather than relying on truthiness: objects
        # that define __len__ (as tokenizers may) are evaluated through it,
        # which can be costly per item or wrongly skip an "empty" object.
        if self.transform is not None:
            image = self.transform(image)
        if self.tokenizer is not None:
            # NOTE(review): with return_tensors="pt" the tokenizer output
            # presumably carries a leading batch dimension of 1 -- confirm
            # that the consumer of this dataset accounts for it.
            text = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=128,
            )
        return image, text, label