<a href="https://colab.research.google.com/github/jeongin7103/BoxNSegAI/blob/main/2DBB/%EA%B9%80%EC%A0%95%EC%9D%B8/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so files stored there are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import sys

# Directory added to the import search path so modules stored there can be imported
# (presumably the `model`, `utils` and `datasets` modules used by this notebook -- verify).
project_dir = '/content/drive/MyDrive/Colab Notebooks/2dbb_ji'

# Guard the append so re-running this cell does not stack duplicate entries in sys.path.
if project_dir not in sys.path:
    sys.path.append(project_dir)

In [3]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
from model import SSD300, MultiBoxLoss
import torch.utils.data
from utils import *
from datasets import CustomDataset

# Names assigned at the top level of a notebook cell are already module-level globals,
# so no `global` declaration is needed here (at module scope it has no effect).

# Data parameters
# Root folder passed to CustomDataset when the training set is built.
data_folder = '/content/drive/MyDrive/Colab Notebooks/2dbb_ji/'

In [4]:
# ---- Model ----
# One output class per entry in label_map (presumably provided by `from utils import *` -- verify).
n_classes = len(label_map)
print(n_classes)

# ---- Device ----
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# ---- Training settings ----
checkpoint = None  # object handed to torch.load to resume training; None builds a fresh model
batch_size = 4
iterations = 3
# workers = 4
print_freq = 100  # a status line is printed every `print_freq` batches

# Optimizer settings
lr = 1e-3
momentum = 0.9  # momentum
weight_decay = 5e-4  # weight decay

# Learning-rate schedule (decay_lr_at is reassigned in a later cell)
decay_lr_at = [80000, 100000]
decay_lr_to = 0.1  # decay learning rate to this fraction of the existing learning rate

grad_clip = None  # gradients are clipped only when this is not None
cudnn.benchmark = True

11
cuda


In [5]:
def train(train_loader, model, criterion, optimizer, epoch):
    """
    Run one epoch of training.

    Reads the module-level settings `device`, `grad_clip` and `print_freq`.

    :param train_loader: iterable yielding (images, boxes, labels) batches
    :param model: network called as model(images); returns (predicted_locs, predicted_scores)
    :param criterion: loss called as criterion(predicted_locs, predicted_scores, boxes, labels)
    :param optimizer: optimizer whose gradients are reset and which is stepped once per batch
    :param epoch: epoch number, used only in the printed status line
    """
    model.train()  # switch the model to training mode

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    start = time.time()

    # Batches
    for i, (images, boxes, labels) in enumerate(train_loader):
        # Time spent waiting for this batch to be produced by the loader
        data_time.update(time.time() - start)

        # Move to default device; boxes and labels are per-image sequences, moved element-wise
        images = images.to(device)  # (batch_size (N), 3, 300, 300)
        boxes = [b.to(device) for b in boxes]
        labels = [l.to(device) for l in labels]

        # Forward prop.
        # `images` is what gets passed on to the model's forward function.
        predicted_locs, predicted_scores = model(images)  # (N, 8732, 4), (N, 8732, n_classes)

        # Loss
        loss = criterion(predicted_locs, predicted_scores, boxes, labels)  # scalar

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)

        # Update model
        optimizer.step()

        # Track the loss weighted by the number of images in this batch
        losses.update(loss.item(), images.size(0))
        batch_time.update(time.time() - start)

        # Restart the timer so the next data_time measurement covers only loader wait
        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader),
                                                                  batch_time=batch_time,
                                                                  data_time=data_time, loss=losses))
    # No explicit `del` of the per-batch tensors: they are locals and are released when the
    # function returns, and deleting them here raised UnboundLocalError for an empty loader.



In [6]:
# Initialize the model and optimizer, either from scratch or from a saved checkpoint.
if checkpoint is None:
    start_epoch = 0
    # Use the class count computed from label_map rather than a hard-coded literal,
    # so the model stays in sync if the label set changes.
    model = SSD300(n_classes=n_classes)

    # Split trainable parameters into biases and everything else;
    # the bias group is given twice the base learning rate below.
    biases = list()
    not_biases = list()
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            if param_name.endswith('.bias'):
                biases.append(param)
            else:
                not_biases.append(param)
    optimizer = torch.optim.SGD(params=[{'params': biases, 'lr': 2 * lr}, {'params': not_biases}],
                                lr=lr, momentum=momentum, weight_decay=weight_decay)

else:
    # `checkpoint` is rebound from a path to the loaded object; skip the load when this
    # cell is re-run and the name already holds the loaded dict.
    if not isinstance(checkpoint, dict):
        # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only runtime.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoints you
        # created yourself. Recent PyTorch releases may also require weights_only=False to
        # load whole pickled model/optimizer objects -- confirm against the installed version.
        checkpoint = torch.load(checkpoint, map_location=device)
    start_epoch = checkpoint['epoch'] + 1
    print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
    model = checkpoint['model']
    optimizer = checkpoint['optimizer']



In [7]:
# Place the network on the selected device.
model = model.to(device)

# Loss function; it is constructed with the model's `priors_cxcy` and moved to the same device.
criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy).to(device)

# Training data. The dataset's own collate_fn is handed to the DataLoader.
train_dataset = CustomDataset(data_folder, split='training')
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    pin_memory=True,
)


In [8]:
# Training schedule, given directly in epochs.
# A disabled alternative derived both values from iteration counts:
#   epochs      = iterations // (len(train_dataset) // 32)
#   decay_lr_at = [it // (len(train_dataset) // 32) for it in decay_lr_at]
decay_lr_at = [8, 9]  # epochs at which the learning-rate decay is applied
epochs = 10  # upper bound of the epoch range used by the training loop
print(epochs)
print(decay_lr_at)

10
[8, 9]


In [9]:
# Main loop: iterate from start_epoch up to (but not including) epochs.
for epoch in range(start_epoch, epochs):
    # Apply the learning-rate adjustment on the scheduled epochs.
    if epoch in decay_lr_at:
        adjust_learning_rate(optimizer, decay_lr_to)

    # Run one epoch of training through the train() function.
    train(train_loader, model, criterion, optimizer, epoch)

    # Save a checkpoint after every epoch.
    save_checkpoint(epoch, model, optimizer)

Epoch: [0][0/200]	Batch Time 9.726 (9.726)	Data Time 1.045 (1.045)	Loss 22.1520 (22.1520)	
Epoch: [0][100/200]	Batch Time 2.458 (3.181)	Data Time 2.237 (2.865)	Loss 36.8416 (22.1324)	
Epoch: [1][0/200]	Batch Time 1.337 (1.337)	Data Time 1.173 (1.173)	Loss 19.7111 (19.7111)	
Epoch: [1][100/200]	Batch Time 2.127 (2.129)	Data Time 1.888 (1.932)	Loss 23.7204 (33.0984)	
Epoch: [2][0/200]	Batch Time 2.297 (2.297)	Data Time 2.062 (2.062)	Loss 33.6972 (33.6972)	
Epoch: [2][100/200]	Batch Time 2.096 (2.138)	Data Time 1.877 (1.937)	Loss 20.9102 (31.1142)	
Epoch: [3][0/200]	Batch Time 1.523 (1.523)	Data Time 1.295 (1.295)	Loss 34.2077 (34.2077)	
Epoch: [3][100/200]	Batch Time 1.489 (2.130)	Data Time 1.336 (1.932)	Loss 22.7399 (33.8085)	
Epoch: [4][0/200]	Batch Time 2.155 (2.155)	Data Time 1.915 (1.915)	Loss 27.3457 (27.3457)	
Epoch: [4][100/200]	Batch Time 1.660 (2.065)	Data Time 1.502 (1.871)	Loss 17.4686 (20.3029)	
Epoch: [5][0/200]	Batch Time 1.032 (1.032)	Data Time 0.877 (0.877)	Loss 16.5404 