# Training a Faster R-CNN (ResNet-50 FPN) model for coin localization

Run the cells in this section if you want to re-train the model on custom data. The trained weights will be saved as "coin_detector.pth" in the notebook's working directory.

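**IMPORTANT:** The training dataset must follow exactly this folder structure. Images and annotations are paired by sorted file name, so every image needs exactly one annotation file whose name sorts into the same position: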
```
dataset
├── images
│   ├── L1010277.JPG
├── annotations
│   ├── L1010277 [1].xml
```


## Prepare the labelled training dataset in the required format

In [32]:
import glob
import shutil
import os

# Create the layout the training dataset reads: dataset/images holds the JPG
# photos and dataset/annotations holds the XML box annotations.
# (makedirs also creates the parent "dataset" folder.)
os.makedirs("dataset/images", exist_ok=True)
os.makedirs("dataset/annotations", exist_ok=True)


def _copy_matching(source_root, pattern, destination):
    """Copy every file matching ``pattern`` found one level below ``source_root``.

    Each entry of ``source_root`` is treated as a sub-folder and globbed with
    ``pattern``; entries that are plain files simply match nothing.

    Returns:
        The number of files copied into ``destination``.
    """
    copied = 0
    for sub_folder in os.listdir(source_root):
        for path in glob.iglob(os.path.join(source_root, sub_folder, pattern)):
            shutil.copy(path, destination)
            copied += 1
    return copied


labelled_images_path = "Labelled_Training_Data"
# XML annotations belong in dataset/annotations (they are parsed as XML there).
image_count = _copy_matching(labelled_images_path, "*.xml", "dataset/annotations")
print(f"Total annotated training images: {image_count}")

images_path = "Data/train"
# Raw JPG photos belong in dataset/images (they are opened as images there).
image_count = _copy_matching(images_path, "*.JPG", "dataset/images")
print(f"Total raw training images: {image_count}")


Total annotated training images: 81
Total raw training images: 81


## Set up the model

In [36]:
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import xml.etree.ElementTree as ET
from PIL import Image


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [6]:
# File folder organization
# Raw training photos, grouped into sub-folders: Data/train/<sub-folder>/
unlabeled_data_folder = os.path.join("Data", "train")
# Annotation files, using the same <sub-folder> names as Data/train
data_labels_folder = "Labelled_Training_Data"

# Destination layout for the paired training data
coin_detection_dataset = "dataset"
coin_detection_images = os.path.join(coin_detection_dataset, "images")
coin_detection_annotations = os.path.join(coin_detection_dataset, "annotations")
# makedirs creates the parent "dataset" folder as needed, so no separate call.
os.makedirs(coin_detection_images, exist_ok=True)
os.makedirs(coin_detection_annotations, exist_ok=True)


In [7]:
# Pair each raw training image in Data/train/<sub-folder>/ with the annotation
# file(s) in Labelled_Training_Data/<sub-folder>/ whose name contains the
# image's base name, and copy both into dataset/images and dataset/annotations.
for folder in os.listdir(unlabeled_data_folder):
    image_folder = os.path.join(unlabeled_data_folder, folder)
    label_folder = os.path.join(data_labels_folder, folder)
    # Stray files at this level (e.g. macOS .DS_Store) are not image sub-folders.
    if not os.path.isdir(image_folder):
        continue
    if not os.path.isdir(label_folder):
        print(f"Warning: no label folder {label_folder!r}; skipping {image_folder!r}")
        continue
    # List the labels once per sub-folder instead of once per image.
    label_files = os.listdir(label_folder)
    for image in os.listdir(image_folder):
        image_filepath = os.path.join(image_folder, image)
        # Hidden files (e.g. .DS_Store) would otherwise pair with hidden label files.
        if image.startswith(".") or not os.path.isfile(image_filepath):
            continue
        # Ex: L1010277.JPG --> L1010277 (independent of the extension's length)
        stem = os.path.splitext(image)[0]
        # An empty stem is a substring of every label name.
        if not stem:
            continue
        matches = [label for label in label_files if stem in label]
        if len(matches) > 1:
            # The dataset pairs images and annotations one-to-one by sorted
            # name, so extra annotations for one image would shift that pairing.
            print(f"Warning: {image!r} matches several annotations: {matches}")
        for label_file in matches:
            label_filepath = os.path.join(label_folder, label_file)
            shutil.copy2(label_filepath, coin_detection_annotations)
        if matches:
            shutil.copy2(image_filepath, coin_detection_images)

In [None]:
# Parser for data
class CoinDataset(torch.utils.data.Dataset):
    """Coin detection dataset read from ``root/images`` and ``root/annotations``.

    The i-th image (sorted by file name) is paired with the i-th XML annotation
    (sorted by file name). Every ``<object>``'s ``<bndbox>`` becomes one box,
    and every box gets label 1 (coin).

    Args:
        root: Dataset folder containing ``images/`` and ``annotations/``.
        transforms: Optional callable applied to the RGB PIL image.

    Raises:
        ValueError: If the number of images and annotation files differ.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # Pairing is purely positional, so drop hidden files (e.g. .DS_Store)
        # and non-XML annotation files that would shift every later pair.
        self.imgs = sorted(
            name
            for name in os.listdir(os.path.join(root, "images"))
            if not name.startswith(".")
        )
        self.annotations = sorted(
            name
            for name in os.listdir(os.path.join(root, "annotations"))
            if name.lower().endswith(".xml") and not name.startswith(".")
        )
        if len(self.imgs) != len(self.annotations):
            raise ValueError(
                f"Found {len(self.imgs)} images but {len(self.annotations)} "
                f"annotations under {root!r}; every image needs exactly one "
                "annotation file."
            )

    @staticmethod
    def _parse_boxes(ann_path):
        """Return the ``[xmin, ymin, xmax, ymax]`` boxes of an XML file as a float32 (N, 4) tensor."""
        annotation = ET.parse(ann_path).getroot()
        boxes = []
        for obj in annotation.findall("object"):
            bbox = obj.find("bndbox")
            boxes.append(
                [float(bbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")]
            )
        # reshape keeps the (N, 4) shape when N == 0 (an image without coins);
        # torch.as_tensor([]) alone would be 1-D, which detection targets reject.
        return torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        ann_path = os.path.join(self.root, "annotations", self.annotations[idx])
        # The context manager releases the file handle once pixels are decoded.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        boxes = self._parse_boxes(ann_path)
        labels = torch.ones((len(boxes),), dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels}

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.imgs)


# Start from torchvision's pretrained Faster R-CNN (ResNet-50 + FPN backbone)
# and replace its box predictor with a fresh head for our two classes.
model = fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2  # class 0 = background, class 1 = coin
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# PIL image -> float tensor; normalization and resizing happen inside the
# model's own GeneralizedRCNNTransform.
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

## Train on the custom dataset

Run the next cell if you want to fine-tune the model on the custom (coin) dataset. The trained weights are saved to "coin_detector.pth".

In [None]:
# Load the dataset
dataset = CoinDataset("dataset/", transforms=transforms)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    # Keep a batch as (tuple of images, tuple of target dicts) instead of
    # stacking: each target holds its own number of boxes, and images may
    # differ in size, so the default tensor-stacking collate does not fit.
    collate_fn=lambda batch: tuple(zip(*batch)),
)

# Fine-tune the model based on our labelled training data
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
num_epochs = 10
log_every = 10  # print the running mean loss every `log_every` batches

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    epoch_loss = 0.0
    print("Beginning epoch", epoch + 1)
    for i, (images, targets) in enumerate(data_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        # In train mode the detector returns a dict of loss terms; optimize their sum.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        losses.backward()
        optimizer.step()

        batch_loss = losses.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        if i % log_every == log_every - 1:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], Loss: {running_loss/log_every:.4f}"
            )
            running_loss = 0.0

    # Per-epoch summary: also covers batches after the last periodic print and
    # datasets with fewer than `log_every` batches, which never print above.
    print(
        f"Epoch [{epoch+1}/{num_epochs}] done, mean loss: {epoch_loss / max(len(data_loader), 1):.4f}"
    )

print("Training finished.")

# Save the trained model
torch.save(model.state_dict(), "coin_detector.pth")

# Load an existing model

Run the next cell instead of the training cell above if you already have a trained model saved as "coin_detector.pth".

In [37]:
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import xml.etree.ElementTree as ET
from PIL import Image

# Load the trained model: rebuild the same architecture used for training
# (Faster R-CNN ResNet-50 FPN with a 2-class box predictor) so the
# checkpoint's keys and tensor shapes match.
model = fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2  # Background and coin
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Put the model on the device the inference cell sends its input tensors to.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters
# wherever they live.
model.load_state_dict(torch.load("coin_detector.pth", map_location="cpu"))

<All keys matched successfully>

# Perform object localization on test images

**IMPORTANT:** The folder structure of input test images must follow exactly:
```
Data
├── test
│   ├── L1010277.JPG
│   ├── L1010239.JPG
```
And the folder structure of output images will be:
```
output
├── L0000000.JPG                        <---------------- This is a directory
│   ├── 1920_1519_2595_2181.jpg         <---------------- This is an image of the 1st cropped coin, belonging to L0000000.JPG
│   ├── 2780_2159_3326_2685.jpg         <---------------- This is an image of the 2nd cropped coin, belonging to L0000000.JPG
├── L0000001.JPG                        
│   ├── 1920_1519_2595_2181.jpg         
│   ├── 2780_2159_3326_2685.jpg         
```
If an "output" folder already exists, it will get overwritten.

In [41]:
# Perform inference on test images and save one crop per detected coin
model.eval()
# Make sure the model sits on the same device as the input tensors below,
# whichever cell (training or loading) built it.
model.to(device)
test_data_path = "Data/test"
output_path = "output"
# Detections scoring below this are skipped; 0.0 keeps every box the model returns.
score_threshold = 0.0

# Start from an empty "output" folder so re-running this cell overwrites the
# previous results instead of failing on already-existing per-image folders.
shutil.rmtree(output_path, ignore_errors=True)
os.makedirs(output_path)
for img_name in sorted(os.listdir(test_data_path)):
    # Only process JPEG images (skips e.g. .DS_Store); case-insensitive check.
    if not img_name.lower().endswith((".jpg", ".jpeg")):
        continue
    img_path = os.path.join(test_data_path, img_name)
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")
    img_tensor = transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img_tensor)
    # Create a directory for each image, where cropped coins will be saved
    img_output_path = os.path.join(output_path, img_name)
    os.mkdir(img_output_path)
    detections = outputs[0]
    boxes = detections["boxes"].cpu().numpy()
    scores = detections["scores"].cpu().numpy()
    # Crop coins from current image; file names encode the pixel box.
    for box, score in zip(boxes, scores):
        if score < score_threshold:
            continue
        xmin, ymin, xmax, ymax = box.astype(int)
        coin = img.crop((xmin, ymin, xmax, ymax))
        cropped_coin_path = os.path.join(
            img_output_path, f"{xmin}_{ymin}_{xmax}_{ymax}.jpg"
        )
        coin.save(cropped_coin_path)

L0000106.JPG
L0000112.JPG
L0000099.JPG
L0000072.JPG
L0000066.JPG
L0000067.JPG
L0000073.JPG
L0000098.JPG
L0000113.JPG
L0000107.JPG
L0000139.JPG
L0000111.JPG
L0000105.JPG
L0000059.JPG
L0000065.JPG
L0000071.JPG
L0000070.JPG
L0000064.JPG
L0000058.JPG
L0000104.JPG
L0000110.JPG
L0000138.JPG
L0000114.JPG
L0000100.JPG
L0000128.JPG
.DS_Store
L0000060.JPG
L0000074.JPG
L0000048.JPG
L0000049.JPG
L0000075.JPG
L0000061.JPG
L0000129.JPG
L0000101.JPG
L0000115.JPG
L0000103.JPG
L0000117.JPG
L0000088.JPG
L0000077.JPG
L0000063.JPG
L0000062.JPG
L0000076.JPG
L0000089.JPG
L0000116.JPG
L0000102.JPG
L0000159.JPG
L0000039.JPG
L0000011.JPG
L0000005.JPG
L0000004.JPG
L0000010.JPG
L0000038.JPG
L0000158.JPG
L0000006.JPG
L0000012.JPG
L0000013.JPG
L0000007.JPG
L0000003.JPG
L0000017.JPG
L0000016.JPG
L0000002.JPG
L0000160.JPG
L0000148.JPG
L0000014.JPG
L0000000.JPG
L0000028.JPG
L0000029.JPG
L0000001.JPG
L0000015.JPG
L0000149.JPG
L0000161.JPG
L0000144.JPG
L0000150.JPG
L0000018.JPG
L0000030.JPG
L0000024.JPG
L0000025.JPG
L0