# In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import cv2
from tqdm import tqdm
from segment_anything import sam_model_registry, SamPredictor
from transformers import CLIPProcessor, CLIPModel
from transunet import TransUNet

# ======= CONFIGURATION =======
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

BATCH_SIZE = 8
IMAGE_SIZE = 224  # side length (pixels) of the square inputs fed to the model

# Dataset roots: C_prime is used for training, D for testing.
ITEM_C_DIR = "C_prime"
ITEM_D_DIR = "D"

# Per-channel RGB normalisation statistics.
# NOTE(review): these look like rounded CLIP image statistics — confirm they
# match what the pretrained TransUNet encoder expects.
_NORM_MEAN = (0.4815, 0.4578, 0.4082)
_NORM_STD = (0.2686, 0.2613, 0.2758)

# Image preprocessing: square resize to IMAGE_SIZE, convert to a tensor, normalise.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Function to load the images and their corresponding masks
def load_data(image_dir, mask_dir):
    """Load every ``.jpg`` image in ``image_dir`` together with its mask.

    The mask for ``<image_dir>/<name>.jpg`` is read from
    ``<mask_dir>/<name>.jpg`` (same file name). Files are processed in sorted
    file-name order.

    Each image is converted to RGB and passed through the module-level
    ``transform`` pipeline, which resizes it to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
    Each mask is converted to single-channel ("L") and resized to the same
    ``IMAGE_SIZE`` x ``IMAGE_SIZE`` with nearest-neighbour sampling, so label
    values are never blended. It is then turned into a ``torch.long`` tensor
    of shape ``(IMAGE_SIZE, IMAGE_SIZE)``.

    Returns:
        A list of ``(image_tensor, mask_tensor)`` pairs. ``DataLoader``'s
        default collate can stack these into image batches and
        ``(N, IMAGE_SIZE, IMAGE_SIZE)`` mask batches, which is the class-index
        target layout ``CrossEntropyLoss`` expects for 2-D outputs.

    Raises:
        FileNotFoundError: if ``image_dir`` or a matching mask file is missing.
    """
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))

    data = []
    for image_name in image_files:
        image_path = os.path.join(image_dir, image_name)
        mask_path = os.path.join(mask_dir, image_name)

        # `with` closes the source files; convert()/resize() return independent
        # in-memory images that stay valid after the file is closed.
        with Image.open(image_path) as img:
            image = transform(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            mask = msk.convert("L").resize((IMAGE_SIZE, IMAGE_SIZE), Image.NEAREST)

        # NOTE(review): pixel values are used directly as class indices, so
        # they must lie in [0, out_channels) of the model. Masks stored as
        # 0/255, or JPEG compression artefacts in .jpg masks, would violate
        # this — confirm the mask encoding.
        mask = torch.tensor(np.array(mask), dtype=torch.long)

        data.append((image, mask))

    return data

# Build the in-memory sample lists for both splits. Each split root holds
# parallel "images/" and "masks/" sub-directories; train is loaded first.
train_dataset, test_dataset = (
    load_data(
        image_dir=os.path.join(split_root, "images"),
        mask_dir=os.path.join(split_root, "masks"),
    )
    for split_root in (ITEM_C_DIR, ITEM_D_DIR)
)

# Batch the samples: shuffle only the training split, keep test order stable.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# Build the TransUNet segmentation model: 3-channel RGB input, 2 output classes.
# `pretrained=True` is forwarded to the project-local TransUNet constructor;
# which weights it loads is defined in the `transunet` module, not here.
model = TransUNet(
    img_size=IMAGE_SIZE,  # must match the Resize in `transform`
    in_channels=3,
    out_channels=2,
    pretrained=True,
)
model.to(DEVICE)
# Start in eval mode; `train()` switches to training mode itself.
model.eval()

# Define a simple training loop
def train(model, train_loader, criterion, optimizer):
    """Run one training epoch over ``train_loader`` and report the mean loss.

    Args:
        model: segmentation network, called as ``model(images)``.
        train_loader: iterable of ``(images, masks)`` batches.
        criterion: loss callable taking ``(logits, masks)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The mean per-batch training loss, or ``float("nan")`` if the loader
        yielded no batches (instead of dividing by zero).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, masks in tqdm(train_loader):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        # CrossEntropyLoss on (N, C, H, W) logits needs (N, H, W) class-index
        # targets; drop a singleton channel axis if the masks carry one.
        if masks.dim() == 4 and masks.size(1) == 1:
            masks = masks.squeeze(1)

        optimizer.zero_grad()
        outputs = model(images)
        # NOTE(review): the original accessed `.logits`; fall back to the raw
        # output when the model returns a plain tensor. Confirm which form
        # this TransUNet returns.
        logits = getattr(outputs, "logits", outputs)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Training loss: {mean_loss}")
    return mean_loss

# Define a simple evaluation loop
def evaluate(model, test_loader):
    """Compute and print the pixel accuracy of ``model`` on ``test_loader``.

    The predicted class for each pixel is the argmax over the channel axis of
    the model output. It is compared element-wise against the mask.

    Args:
        model: segmentation network, called as ``model(images)``.
        test_loader: iterable of ``(images, masks)`` batches.

    Returns:
        Fraction of correctly classified pixels, or ``float("nan")`` if the
        loader contained no pixels.

    Raises:
        ValueError: if prediction and mask shapes differ. Without this check a
            shape mismatch would silently broadcast and inflate the count.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            # Drop a singleton channel axis so masks line up with the
            # (N, H, W) argmax predictions.
            if masks.dim() == 4 and masks.size(1) == 1:
                masks = masks.squeeze(1)

            outputs = model(images)
            # NOTE(review): the original accessed `.logits`; fall back to the
            # raw output when the model returns a plain tensor — confirm.
            logits = getattr(outputs, "logits", outputs)
            predicted = torch.argmax(logits, dim=1)

            if predicted.shape != masks.shape:
                raise ValueError(
                    f"prediction shape {tuple(predicted.shape)} does not match "
                    f"mask shape {tuple(masks.shape)}"
                )

            total += masks.numel()
            correct += (predicted == masks).sum().item()

    accuracy = correct / total if total else float("nan")
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Optimise every model parameter with Adam, scoring per-pixel class logits
# with cross-entropy.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

# One pass over the training split, then accuracy on the held-out split.
train(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
)
evaluate(model=model, test_loader=test_loader)