In [1]:
import torch
import os
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

In [2]:
class MinecraftDataset(Dataset):
    """Detection dataset pairing ``.jpg`` images with YOLO-style label files.

    Each image ``<name>.jpg`` in ``img_dir`` may have a sibling ``<name>.txt``
    in ``label_dir``. Every label line is read as
    ``class_id x_center y_center width height`` with the four box values
    normalised to ``[0, 1]``. Boxes are converted to absolute
    ``[x_min, y_min, x_max, y_max]`` pixel coordinates on a
    ``width`` x ``height`` canvas, and class ids are shifted by ``+1`` so
    that label ``0`` is never produced by the dataset.

    Args:
        img_dir: Directory containing the ``.jpg`` images.
        label_dir: Directory containing the ``.txt`` label files.
        width: Pixel width used to scale boxes; images are resized to it.
        height: Pixel height used to scale boxes; images are resized to it.
        transforms: Stored on the instance only.
            NOTE(review): it is not applied in ``__getitem__`` because its
            expected call signature is not visible in this notebook.
    """

    def __init__(self, img_dir, label_dir, width, height, transforms=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.width = width
        self.height = height
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # idx -> file mapping (and therefore image_id) stable between runs.
        self.imgs = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))

    def _label_path(self, img_name):
        """Return the label file path for ``img_name``.

        Only the final extension is swapped, so a name that contains
        ``.jpg`` earlier in the string is left intact.
        """
        stem = os.path.splitext(img_name)[0]
        return os.path.join(self.label_dir, stem + '.txt')

    def _read_annotations(self, label_path):
        """Parse one label file into ``(boxes, labels)`` Python lists.

        Returns two empty lists when the file does not exist. Lines with
        fewer than five fields are skipped. Boxes are clamped to the
        ``width`` x ``height`` canvas, and boxes left without a positive
        width and height are dropped because torchvision detection models
        reject such boxes during training.

        Raises:
            ValueError: If one of the first five fields is not numeric.
        """
        boxes, labels = [], []
        if not os.path.exists(label_path):
            return boxes, labels

        max_x = float(self.width)
        max_y = float(self.height)
        with open(label_path, 'r') as f:
            for line in f:
                fields = line.split()
                if len(fields) < 5:
                    continue
                class_id, x_c, y_c, w, h = map(float, fields[:5])

                # Normalised centre/size -> absolute corner coordinates.
                x_min = min(max((x_c - w / 2) * max_x, 0.0), max_x)
                y_min = min(max((y_c - h / 2) * max_y, 0.0), max_y)
                x_max = min(max((x_c + w / 2) * max_x, 0.0), max_x)
                y_max = min(max((y_c + h / 2) * max_y, 0.0), max_y)

                if x_max <= x_min or y_max <= y_min:
                    continue  # degenerate box: nothing to learn from it

                boxes.append([x_min, y_min, x_max, y_max])
                labels.append(int(class_id) + 1)
        return boxes, labels

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A tuple ``(img, target)``: ``img`` is a float32 tensor in
            channel-first layout with values scaled by ``1/255``; ``target``
            is a dict with ``"boxes"`` (float32, ``(N, 4)``), ``"labels"``
            (int64, ``(N,)``) and ``"image_id"`` (tensor holding ``idx``).

        Raises:
            FileNotFoundError: If the image cannot be read by OpenCV.
        """
        img_name = self.imgs[idx]
        img_path = os.path.join(self.img_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Boxes are expressed on a width x height canvas, so the pixels must
        # live on the same canvas; otherwise boxes and image disagree.
        if img.shape[0] != self.height or img.shape[1] != self.width:
            img = cv2.resize(img, (int(self.width), int(self.height)))
        img = img.astype(np.float32) / 255.0

        boxes, labels = self._read_annotations(self._label_path(img_name))

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
        else:
            # Keep the (0, 4) shape for images without any annotation.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
        }

        # HWC -> CHW.
        img = torch.as_tensor(img.transpose((2, 0, 1)), dtype=torch.float32)

        return img, target

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``img_dir``."""
        return len(self.imgs)

In [3]:
def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    A batch ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; nothing is stacked, so samples may
    differ in size. An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [4]:
# Root of the dataset; the train split is read from train/images and train/labels.
dataset_location = "./dataset"

TRAIN_DIR_IMG = f"{dataset_location}/train/images"
TRAIN_DIR_LBL = f"{dataset_location}/train/labels"

# Tunables kept in one place instead of being buried in the calls below.
IMG_WIDTH = 640
IMG_HEIGHT = 640
BATCH_SIZE = 4

dataset_train = MinecraftDataset(TRAIN_DIR_IMG, TRAIN_DIR_LBL, IMG_WIDTH, IMG_HEIGHT)
# Fail early with a clear message when the folder is empty or the path is wrong.
assert len(dataset_train) > 0, f"No .jpg images found in {TRAIN_DIR_IMG}"

# collate_fn keeps images/targets as tuples because samples have a varying number of boxes.
data_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [5]:
def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) with a fresh box-classification head.

    The detector is created with the default pretrained weights, then its
    box predictor is replaced by a new ``FastRCNNPredictor`` that outputs
    ``num_classes`` classes while keeping the same input feature size.

    Args:
        num_classes: Number of output classes of the new head. In this
            notebook it is called with the dataset class count plus one.

    Returns:
        The detector with the replaced head.
    """
    detector = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

    roi_heads = detector.roi_heads
    head_inputs = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(head_inputs, num_classes)

    return detector

In [6]:
# Number of object classes annotated in the dataset.
num_classes_dataset = 11
# One extra output slot: MinecraftDataset stores labels as class_id + 1, so
# index 0 is never used by a dataset label (torchvision detection models
# treat class 0 as background).
model = get_model(num_classes_dataset + 1)

In [7]:
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Moves the model's parameters and buffers; as the last expression of the
# cell it also displays the model structure.
model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [8]:
# Warm-up phase: freeze every backbone parameter so only the remaining
# parts of the detector are updated.
model.backbone.requires_grad_(False)

# Give the optimizer only the parameters that still require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

In [9]:
# Epochs trained while the backbone is frozen.
num_epochs_warmup = 5
# Epochs trained afterwards with the backbone unfrozen; the fine-tune loop
# numbers its epochs starting from num_epochs_warmup.
num_epochs_finetune = 10

In [10]:
# Warm-up training: runs num_epochs_warmup epochs with the current optimizer
# and prints the mean per-batch loss after each epoch.
for epoch in range(num_epochs_warmup):
    model.train()  # in train mode the detector returns a dict of losses
    total_loss = 0.0
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Stop before a NaN/inf loss is back-propagated: one such step
        # corrupts the weights and every later epoch would be wasted.
        if not torch.isfinite(losses):
            loss_report = {k: v.item() for k, v in loss_dict.items()}
            raise RuntimeError(f"Non-finite loss at epoch {epoch}: {loss_report}")

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()
    print(f"Epoch {epoch} Loss: {total_loss/len(data_loader)}")

Epoch 0 Loss: 0.38052579299138306
Epoch 1 Loss: 0.304536140344122
Epoch 2 Loss: 0.28230795634158873
Epoch 3 Loss: 0.2666182705969141
Epoch 4 Loss: 0.2514713292712705


In [11]:
# Fine-tune phase: make every backbone parameter trainable again.
model.backbone.requires_grad_(True)

# New optimizer over all model parameters, with a learning rate ten times
# smaller than the one used during warm-up.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.0005,
    momentum=0.9,
    weight_decay=0.0005,
)

In [None]:
# Fine-tune training: runs num_epochs_finetune epochs and prints the mean
# per-batch loss, numbering epochs as a continuation of the warm-up phase.
for epoch in range(num_epochs_finetune):
    real_epoch = epoch + num_epochs_warmup
    model.train()  # in train mode the detector returns a dict of losses
    total_loss = 0.0
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Stop before a NaN/inf loss is back-propagated: one such step
        # corrupts the weights and every later epoch would be wasted.
        if not torch.isfinite(losses):
            loss_report = {k: v.item() for k, v in loss_dict.items()}
            raise RuntimeError(f"Non-finite loss at epoch {real_epoch}: {loss_report}")

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()
    print(f"Epoch {real_epoch} (FT) Loss: {total_loss/len(data_loader)}")

In [None]:
# Persist only the weights (state_dict), not the model object: to reload,
# rebuild the architecture with get_model(...) and call load_state_dict().
torch.save(model.state_dict(), 'minecraft_model_finetuned.pth')