# Task 3

In [18]:
import torch
from torchvision.transforms import functional as F
from torchvision.datasets import CocoDetection
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, ToTensor

In [19]:
class LegoDataset(CocoDetection):
    """COCO-format dataset that yields samples in the torchvision detection format.

    Each item is ``(image, target)`` where ``image`` is the tensor produced by
    ``F.to_tensor`` and ``target`` is a dict with:

    * ``"boxes"``  -- float32 tensor of shape ``(N, 4)`` in
      ``[xmin, ymin, xmax, ymax]`` order
    * ``"labels"`` -- int64 tensor of shape ``(N,)`` holding each annotation's
      ``category_id``
    """

    def __getitem__(self, idx):
        """Return the image at ``idx`` as a tensor together with its converted target."""
        img, annotations = super().__getitem__(idx)
        img = F.to_tensor(img)

        # Annotations store boxes as [xmin, ymin, width, height]; convert them
        # to corner format [xmin, ymin, xmax, ymax].
        boxes = []
        labels = []
        for obj in annotations:
            xmin, ymin, width, height = obj['bbox']
            boxes.append([xmin, ymin, xmin + width, ymin + height])
            labels.append(obj['category_id'])

        # reshape(-1, 4) keeps the (N, 4) layout even when the image has no
        # annotations; an empty list would otherwise become a shape-(0,)
        # tensor, which torchvision's detection models refuse as a box tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels}

        return img, target

# Images are read from data/imgs; their annotations come from the JSON file.
dataset = LegoDataset(root='data/imgs', annFile='faster_r_cnn.json')

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [20]:
def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a new box-predictor head.

    Args:
        num_classes: Number of output classes handed to ``FastRCNNPredictor``.

    Returns:
        The pre-trained detection model whose ``roi_heads.box_predictor`` has
        been replaced by a freshly constructed ``FastRCNNPredictor``.
    """
    # Start from the pre-trained detection model.
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # The new head must accept the same input width as the stock classifier layer.
    roi_heads = detector.roi_heads
    feature_dim = roi_heads.box_predictor.cls_score.in_features

    # Swap in a head sized for our own number of classes.
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)

    return detector

# Two output classes: the single "lego" class plus background.
num_classes = 2
model = get_model(num_classes=num_classes)

In [21]:
# Split the dataset into training (80%) and validation (20%) subsets.
# The seeded generator makes the split reproducible: without it every kernel
# restart assigns different images to each subset.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)


def collate_fn(batch):
    """Regroup a list of (image, target) pairs into (images, targets) tuples.

    The samples are kept as separate items rather than stacked into a single
    batch tensor.
    """
    return tuple(zip(*batch))


# Create data loaders (one shared, named collate function instead of two lambdas)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)


# Move model to the correct device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Construct an optimizer over the trainable parameters only
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Training function
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one pass over ``data_loader`` and return the mean loss.

    Args:
        model: Module that, in training mode, returns a dict of loss tensors
            when called as ``model(images, targets)``.
        optimizer: Optimizer that updates the model's parameters.
        data_loader: Iterable yielding ``(images, targets)`` batches, where
            ``images`` is a sequence of tensors and ``targets`` a sequence of
            dicts of tensors.
        device: Device the images and target tensors are moved to.
        epoch: Index of the current epoch; only used in the error message.

    Returns:
        The mean over batches of the summed loss, as a float, or ``nan`` when
        ``data_loader`` yields no batches.

    Raises:
        RuntimeError: If a batch produces a non-finite total loss. The
            optimizer step for that batch is not taken, so the weights are
            left as they were before the bad batch.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Stop before stepping: updating with a NaN/inf loss would poison
        # every parameter and the run could not recover.
        if not torch.isfinite(losses):
            details = {name: float(value) for name, value in loss_dict.items()}
            raise RuntimeError(
                f"Non-finite loss in epoch {epoch}, batch {num_batches}: {details}"
            )

        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches

# Run the training loop for the configured number of epochs.
num_epochs = 1
for epoch in range(num_epochs):
    train_one_epoch(
        model=model,
        optimizer=optimizer,
        data_loader=train_loader,
        device=device,
        epoch=epoch,
    )
    print(f"Epoch {epoch} finished.")

# Persist the learned weights (the state dict) to disk.
torch.save(model.state_dict(), 'lego_fasterrcnn.pth')


Epoch 0 finished.


In [23]:


# Load the trained weights. map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU machine also loads on a
# CPU-only one instead of failing while trying to restore CUDA tensors.
state_dict = torch.load('lego_fasterrcnn.pth', map_location=device)
model.load_state_dict(state_dict)
# Switch to evaluation mode before running inference.
model.eval()

def predict(image_path):
    """Run the module-level ``model`` on one image file and return its output.

    Args:
        image_path: Path of the image to open with PIL.

    Returns:
        Whatever ``model`` returns for a batch containing this single image.
    """
    rgb_image = Image.open(image_path).convert("RGB")

    # Add a leading batch dimension and move the tensor to the model's device.
    batch = F.to_tensor(rgb_image)
    batch = batch.unsqueeze(0)
    batch = batch.to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        return model(batch)

# Run the detector on one sample image and print the raw result.
image_path = 'data/imgs/30_1.jpg'
prediction = predict(image_path=image_path)
print(prediction)


[{'boxes': tensor([[134.4769, 119.0376, 146.9966, 129.0684],
        [105.3467, 114.2255, 116.8834, 121.9928],
        [105.5964, 196.2340, 125.1398, 209.3532],
        [ 75.2217, 187.8367,  89.4976, 210.7990],
        [125.0777, 140.9013, 148.8088, 160.8630],
        [152.5215, 131.8268, 165.5538, 142.7635],
        [ 94.1864,  15.4764, 107.1057,  31.8908],
        [110.8575,  92.5573, 131.7252, 102.0790],
        [108.3662, 155.0198, 125.0086, 167.4651],
        [129.9701,  14.4197, 140.9499,  22.2021],
        [151.6155,  72.7795, 169.2446,  90.5884],
        [145.1336,  51.8578, 156.2168,  62.7207],
        [ 75.4288, 150.8042,  99.8903, 169.6171],
        [144.3372, 103.5472, 154.2623, 113.0652],
        [ 79.7614,  30.8931,  88.3207,  39.4577],
        [ 84.5133,  86.0216,  98.9886,  97.4525],
        [107.8348,  49.0068, 126.9868,  65.1537],
        [123.0213, 167.2093, 132.4824, 176.1359],
        [167.1742, 166.3310, 188.6536, 193.1087],
        [ 78.1562,  69.6970,  98.3133, 