In [1]:
from PIL import Image
import torchvision.transforms as transforms
import torch
import numpy as np
import math
import sys

### Training the model

#### Implementing custom dataset

In [2]:
sys.path.append('./dataset')
from deep_fashion_2 import DeepFashion2Dataset 

In [3]:
# Root directories of the dataset splits, relative to the notebook's location.
data_path = dict(
    train="../data/train",
    val="../data/validation",
)

In [4]:
def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    Given ``[(a0, b0), (a1, b1), ...]`` this returns ``((a0, a1, ...), (b0, b1, ...))``
    without stacking the fields. An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Index range of the training split handed to DeepFashion2Dataset
# (presumably sample-index bounds -- confirm against the dataset class).
train_start_index = 15000
train_end_index = 25000
train_dataset = DeepFashion2Dataset(data_path['train'], train_start_index, train_end_index)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=0
)

# The validation range is a quarter of the training range's length.
# Floor division keeps the end index an int; true division (`/`) would produce
# the float 5500.0, which is not a valid bound for range()/slicing/len().
val_start_index = 3000
val_end_index = val_start_index + (train_end_index - train_start_index) // 4
val_dataset = DeepFashion2Dataset(data_path['val'], val_start_index, val_end_index)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=0
)

#### Setting up the model

In [None]:
# Checkpoint file read when `is_import_model` is enabled.
model_path = "../model/faster_rcnn_resnet101_v2.pth"

# True: start from the checkpoint at `model_path`; False: build a fresh model.
is_import_model = True

In [5]:
sys.path.append('./model')
from faster_r_cnn import FasterRCNNResNet101 

In [6]:
# Pick the device first so the checkpoint can be mapped straight onto it.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

if is_import_model:
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, torch.nn.Module):
        # The file holds a whole pickled model object.
        model = checkpoint
    else:
        # The file holds a state_dict (this is what the save cell of this notebook
        # writes), so the weights must be loaded into a freshly built model.
        model = FasterRCNNResNet101(num_classes=14)
        model.load_state_dict(checkpoint)
else:
    model = FasterRCNNResNet101(num_classes=14)

# Last expression of the cell: moves the model to `device` and displays its repr.
model.to(device)



FasterRCNNResNet101(
  (model): FasterRCNN(
    (transform): GeneralizedRCNNTransform(
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        Resize(min_size=(800,), max_size=1333, mode='bilinear')
    )
    (backbone): BackboneWithFPN(
      (body): IntermediateLayerGetter(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): FrozenBatchNorm2d(64, eps=1e-05)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): FrozenBatchNorm2d(64, eps=1e-05)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): FrozenBatchNorm2d(64, eps=1e-05)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [7]:
# Hand only the trainable (requires_grad) parameters to the optimizer.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# StepLR scales the learning rate by `gamma` every `step_size` scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

#### Training loop

In [None]:
num_epochs = 5


def _batch_loss(images, targets):
    """Move one batch onto `device` and return the model's summed loss.

    Args:
        images: iterable of image tensors for one batch.
        targets: iterable of dicts mapping target names to tensors, one per image.

    Returns:
        The sum of every value in the loss dict returned by `model(images, targets)`.
    """
    images = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    loss_dict = model(images, targets)
    return sum(loss_dict.values())


for epoch in range(num_epochs):
    model.train()

    # ---- training pass ----
    running_loss = 0.0
    for images, targets in train_dataloader:
        optimizer.zero_grad()

        losses = _batch_loss(images, targets)
        losses.backward()
        optimizer.step()

        running_loss += losses.item()

    avg_train_loss = running_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Training Loss: {avg_train_loss}")

    lr_scheduler.step()

    # ---- validation pass ----
    # The model is deliberately left in train mode: the loop needs a loss dict from
    # model(images, targets), and torchvision detection models only return losses in
    # training mode (assumes the FasterRCNNResNet101 wrapper forwards that behaviour
    # -- TODO confirm). no_grad() ensures no gradients are tracked here.
    val_running_loss = 0.0
    with torch.no_grad():
        for images, targets in val_dataloader:
            val_running_loss += _batch_loss(images, targets).item()

    avg_val_loss = val_running_loss / len(val_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Validation Loss: {avg_val_loss}")

print("Training complete!")


#### Saving trained model

In [9]:
# Persist only the weights (state_dict), not the pickled module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "../model/faster_rcnn_resnet101_v3.pth")