In [4]:
import os
import json
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import timm
import numpy as np
# Root folder that holds the image folders ("<split>2017") and "annotations/".
# The location is machine-specific, so it can be overridden without editing the
# notebook by setting the LVIS_BASE_DIR environment variable; the previous
# hard-coded path remains the default.
BASE_DIR = os.environ.get(
    "LVIS_BASE_DIR",
    r"C:\Users\meme machine\Downloads\Capstone-20250618T001711Z-1-001\Capstone\project",
)
IMG_SIZE = 224      # side length (pixels) every image is resized to
PATCH_SIZE = 16     # side length (pixels) of one ViT patch
# Width of the classification part of the head. Annotation `category_id`
# values are used directly as class indices, so this must exceed the largest id.
# NOTE(review): presumably 1203 LVIS categories plus an unused index 0 - confirm.
NUM_CLASSES = 1204
BATCH_SIZE = 32
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The patch grid built from these two constants only lines up with the ViT
# token grid when the image is an exact multiple of the patch size.
assert IMG_SIZE % PATCH_SIZE == 0, "IMG_SIZE must be a multiple of PATCH_SIZE"


In [5]:

# Image preprocessing used when building the datasets: resize to the fixed
# model input size, apply a light photometric jitter, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5.
#
# Only photometric augmentation is applied here. This pipeline sees the image
# alone, never the bounding boxes, so a geometric augmentation such as
# RandomHorizontalFlip would mirror the pixels while leaving the box
# coordinates untouched and silently corrupt the targets. It was removed for
# that reason; re-introduce it only together with a matching box transform.
transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ColorJitter(0.1, 0.1, 0.1),
    T.ToTensor(),
    T.Normalize([0.5]*3, [0.5]*3)
])

class ManualLVISDataset(Dataset):
    """Minimal dataset over a COCO/LVIS-style annotation JSON file.

    The JSON must contain an ``"images"`` list (dicts with at least ``"id"``)
    and an ``"annotations"`` list (dicts with at least ``"image_id"``).

    Args:
        image_root: Directory that contains the image files.
        ann_file: Path to the annotation JSON file.
        transforms: Optional callable applied to the loaded RGB PIL image.
        rescale_boxes: When true (default), each returned annotation's
            ``"bbox"`` is scaled to the spatial size of the transformed
            image, so boxes and pixels share one coordinate frame. Pass
            ``False`` to get the annotations exactly as stored in the file.

    Each item is ``(image, anns)`` where ``anns`` is the (possibly empty) list
    of annotation dicts for that image.

    Caveat: only a pure resize is compensated for. Random geometric
    augmentations inside ``transforms`` (flips, crops, ...) are not reflected
    in the boxes.
    """

    def __init__(self, image_root, ann_file, transforms=None, rescale_boxes=True):
        self.image_root = image_root
        self.transforms = transforms
        self.rescale_boxes = rescale_boxes
        # JSON is UTF-8 by specification; do not rely on the platform's
        # default text encoding (which is not UTF-8 on every OS).
        with open(ann_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.images = data['images']
        self.annotations = data['annotations']
        # image id -> list of annotation dicts belonging to that image
        self.image_to_anns = {}
        for ann in self.annotations:
            self.image_to_anns.setdefault(ann['image_id'], []).append(ann)
        # image id -> image record
        self.id_to_image = {img['id']: img for img in self.images}

    @staticmethod
    def _rescale_annotations(anns, scale_x, scale_y):
        """Return shallow copies of ``anns`` with ``[x, y, w, h]`` boxes scaled.

        The input dicts are left untouched so the cached annotations are never
        modified by repeated indexing.
        """
        scaled = []
        for ann in anns:
            x, y, w, h = ann["bbox"]
            scaled.append({
                **ann,
                "bbox": [x * scale_x, y * scale_y, w * scale_x, h * scale_y],
            })
        return scaled

    def __getitem__(self, index):
        img_info = self.images[index]
        img_id = img_info['id']
        # Records without a "file_name" fall back to the zero-padded 12-digit id.
        file_name = img_info.get("file_name", f"{img_id:012d}.jpg")
        img_path = os.path.join(self.image_root, file_name)
        # The context manager guarantees the file handle is released even if
        # decoding fails; convert() returns an independent in-memory image.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        orig_w, orig_h = image.size
        if self.transforms:
            image = self.transforms(image)
        anns = self.image_to_anns.get(img_id, [])
        if self.rescale_boxes and anns:
            # Tensors expose (..., H, W) via .shape; PIL images expose (W, H) via .size.
            if hasattr(image, "shape"):
                new_h, new_w = image.shape[-2], image.shape[-1]
            else:
                new_w, new_h = image.size
            if (new_w, new_h) != (orig_w, orig_h):
                anns = self._rescale_annotations(anns, new_w / orig_w, new_h / orig_h)
        return image, anns

    def __len__(self):
        return len(self.images)

def get_dataset(split, transform):
    """Build the dataset for one split located under ``BASE_DIR``.

    Images are read from ``<BASE_DIR>/<split>2017`` and annotations from
    ``<BASE_DIR>/annotations/lvis_v1_<split>.json``.

    Args:
        split: Split name used in both paths (e.g. ``"train"`` or ``"val"``).
        transform: Image transform handed to the dataset.

    Returns:
        A ``ManualLVISDataset`` for the requested split.
    """
    annotation_file = os.path.join(BASE_DIR, "annotations", f"lvis_v1_{split}.json")
    image_folder = os.path.join(BASE_DIR, f"{split}2017")
    return ManualLVISDataset(
        image_root=image_folder,
        ann_file=annotation_file,
        transforms=transform,
    )

def collate_batch(batch):
    """Regroup ``[(image, anns), ...]`` into ``(images, anns_lists)`` tuples.

    Images are deliberately left unstacked because every sample carries a
    different number of annotations. A named module-level function (rather
    than a lambda) can be pickled, which DataLoader needs as soon as
    ``num_workers > 0`` on platforms that spawn worker processes.
    """
    return tuple(zip(*batch))


# Validation must be deterministic, so it gets the same resize/normalisation
# as training but none of the random augmentation.
val_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize([0.5]*3, [0.5]*3)  # keep identical to the training normalisation
])

train_dataset = get_dataset("train", transform)
val_dataset = get_dataset("val", val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

class ViTDetectionModel(nn.Module):
    """ViT backbone with a single linear head applied to every patch token.

    For each patch token the head emits ``4 + num_classes`` values: the first
    four are consumed downstream as box regression outputs, the remainder as
    class logits.

    Args:
        num_classes: Number of class logits produced per patch.
    """

    def __init__(self, num_classes):
        super().__init__()
        # num_classes=0 asks timm for the backbone without its own classifier.
        self.vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
        self.hidden_dim = self.vit.embed_dim
        self.head = nn.Linear(self.hidden_dim, 4 + num_classes)

    def forward(self, x):
        """Return per-patch predictions shaped ``(batch, patches, 4 + num_classes)``."""
        tokens = self.vit.forward_features(x)
        # Drop the non-patch prefix tokens (class token etc.). Ask the backbone
        # how many it prepends instead of hard-coding one; fall back to 1 for
        # backbones that do not expose the attribute.
        num_prefix = getattr(self.vit, "num_prefix_tokens", 1)
        return self.head(tokens[:, num_prefix:])

# Build the detector (its constructor requests a pretrained ViT backbone) and
# move all parameters to the selected device.
model = ViTDetectionModel(NUM_CLASSES).to(DEVICE)





In [None]:
def get_patch_center_coords(img_size=224, patch_size=16):
    """List the pixel centre of every patch in a square image grid.

    Args:
        img_size: Side length of the image in pixels.
        patch_size: Side length of one patch in pixels.

    Returns:
        A list of ``(cx, cy)`` integer tuples in row-major order: all patches
        of the top row from left to right, then the next row, and so on.
    """
    half = patch_size // 2
    origins = range(0, img_size, patch_size)
    return [(left + half, top + half) for top in origins for left in origins]

# Patch centres (cx, cy) for the configured image/patch size, computed once
# and shared by the target-assignment code.
PATCH_CENTERS = get_patch_center_coords(IMG_SIZE, PATCH_SIZE)

def assign_gt_to_patches(boxes):
    """Assign each ground-truth box to the patch whose centre is nearest.

    Args:
        boxes: Iterable of ``[x1, y1, x2, y2]`` tensors. They must be expressed
            in the same pixel coordinate frame as ``PATCH_CENTERS``.

    Returns:
        A list with one entry per patch: the box assigned to that patch, or
        ``None`` when no box was assigned. If several boxes map to the same
        patch, the last one wins.
    """
    # None (not -1) marks "no box" so that callers can test `is None`.
    patch_targets = [None] * len(PATCH_CENTERS)

    for box in boxes:
        x1, y1, x2, y2 = box.tolist()
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        # Squared distance is enough for picking the minimum.
        dists = [(px - cx) ** 2 + (py - cy) ** 2 for px, py in PATCH_CENTERS]
        min_idx = int(np.argmin(dists))
        patch_targets[min_idx] = box
    return patch_targets

def detection_loss(preds, targets):
    """Sum of classification and box-regression losses over matched patches.

    Every ground-truth box is matched to one patch (via
    ``assign_gt_to_patches``). For each matched patch, the first four outputs
    are regressed onto the box with Smooth-L1 and the remaining outputs are
    scored against the box's own label with cross-entropy. Patches without a
    box contribute nothing.

    Args:
        preds: Tensor indexed as ``preds[i][patch_idx]`` with ``4 + classes``
            values per patch.
        targets: One dict per image with ``"boxes"`` (``[x1, y1, x2, y2]`` rows)
            and ``"labels"`` (class indices), row-aligned with each other.

    Returns:
        A scalar tensor. When nothing in the batch could be matched, a zero
        that is still attached to ``preds`` is returned so that calling
        ``.backward()`` on it remains valid.
    """
    cls_loss_fn = nn.CrossEntropyLoss()
    reg_loss_fn = nn.SmoothL1Loss()
    total_cls_loss, total_reg_loss = 0.0, 0.0
    num_matched = 0

    for pred, gt in zip(preds, targets):
        boxes = gt['boxes'].to(DEVICE)
        labels = gt['labels'].to(DEVICE)

        if boxes.numel() == 0 or labels.numel() == 0:
            continue

        # The assignment helper returns boxes only, so it is queried one box at
        # a time to recover which ground-truth row (and therefore which label)
        # landed on which patch. Later boxes overwrite earlier ones on the same
        # patch. Any non-tensor filler value means "no box on this patch".
        patch_to_gt = {}
        for gt_idx in range(boxes.shape[0]):
            per_box = assign_gt_to_patches(boxes[gt_idx:gt_idx + 1])
            for patch_idx, assigned in enumerate(per_box):
                if isinstance(assigned, torch.Tensor):
                    patch_to_gt[patch_idx] = gt_idx
                    break

        for patch_idx, gt_idx in patch_to_gt.items():
            pred_box = pred[patch_idx, :4]
            pred_cls = pred[patch_idx, 4:]
            box_target = boxes[gt_idx].to(pred_box.dtype)
            label = labels[gt_idx].to(torch.long)

            total_reg_loss = total_reg_loss + reg_loss_fn(pred_box, box_target)
            total_cls_loss = total_cls_loss + cls_loss_fn(pred_cls.unsqueeze(0), label.unsqueeze(0))
            num_matched += 1

    if num_matched == 0:
        # Graph-connected zero: keeps loss.backward() working for empty batches.
        return preds.sum() * 0.0
    return total_cls_loss + total_reg_loss

def preprocess_targets(targets):
    """Convert raw annotation lists into per-image box/label tensors.

    Args:
        targets: Iterable with one list of annotation dicts per image. Each
            dict provides ``"bbox"`` as ``[x, y, w, h]`` and ``"category_id"``.

    Returns:
        A list with one dict per image:
        ``"boxes"`` - float32 tensor of shape ``(N, 4)`` in ``[x1, y1, x2, y2]``
        form, and ``"labels"`` - int64 tensor of shape ``(N,)``. ``N`` is 0 for
        images without annotations.
    """
    results = []
    for ann_list in targets:
        corner_boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in (ann["bbox"] for ann in ann_list)
        ]
        class_ids = [ann["category_id"] for ann in ann_list]

        # The dtype is fixed explicitly: letting torch infer it would yield
        # int64 boxes whenever an annotation happens to store integer values.
        if corner_boxes:
            boxes = torch.tensor(corner_boxes, dtype=torch.float32)
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
        labels = torch.tensor(class_ids, dtype=torch.int64)

        results.append({"boxes": boxes, "labels": labels})
    return results

def train_one_epoch(model, dataloader, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module mapping a stacked image batch to per-patch predictions.
        dataloader: Yields ``(images, targets)`` where ``images`` is a sequence
            of equally sized image tensors and ``targets`` a sequence of raw
            annotation lists.
        optimizer: Optimiser bound to ``model``'s parameters.

    Returns:
        The summed batch loss divided by the number of batches seen
        (``0.0`` when the loader yields nothing).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, targets in dataloader:
        num_batches += 1
        images = torch.stack(images).to(DEVICE)
        targets = preprocess_targets(targets)
        preds = model(images)
        loss = detection_loss(preds, targets)
        # A batch in which no image has a usable annotation can produce a loss
        # that carries no gradient (or is a plain number); there is nothing to
        # optimise for it, and calling backward() on it would raise.
        if not torch.is_tensor(loss) or not loss.requires_grad:
            continue
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / num_batches if num_batches else 0.0

@torch.no_grad()
def evaluate(model, dataloader):
    """Smoke-check the model on the first batch of ``dataloader``.

    Switches the model to eval mode, runs a single forward pass without
    gradients and prints the shape of the output. No metric is computed and
    nothing is returned; an empty loader results in no output at all.
    """
    model.eval()
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    batch_images, _ = first_batch
    stacked = torch.stack(batch_images).to(DEVICE)
    outputs = model(stacked)
    print("Validation output shape:", outputs.shape)

# Optimisation settings for this run, named so they are easy to find and tune.
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# One training pass followed by a quick forward-pass check on the validation
# loader, repeated for every epoch.
for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}")
    evaluate(model, val_loader)

In [None]:
# Read the training annotation file to build a lookup from category id to its
# human-readable name.
train_ann_path = os.path.join(BASE_DIR, "annotations", "lvis_v1_train.json")
with open(train_ann_path, "r") as f:
    lvis_data = json.load(f)

id_to_class = dict(
    (category["id"], category["name"]) for category in lvis_data["categories"]
)

In [6]:
# The checkpoint loaded below is expected to have been written by this
# (currently disabled) snippet after training:
# model_save_path = os.path.join(BASE_DIR, "vit_llvis_model.pth")
# torch.save(model.state_dict(), model_save_path)
# print(f"Model saved to: {model_save_path}")
checkpoint_path = os.path.join(BASE_DIR, "vit_llvis_model.pth")

# Start from a fresh model instance, then overwrite its weights with the
# saved state dict and switch to inference mode.
model = ViTDetectionModel(NUM_CLASSES).to(DEVICE)

# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
model.eval()
print("Model loaded and ready.")

Model loaded and ready.
