In [None]:
!nvidia-smi

Sat Jun  4 07:16:28 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install tensorboardX
!pip install --force-reinstall albumentations==1.0.3
!pip install opencv-python-headless==4.5.2.52
!pip install tensorboardcolab


import torch.optim as optim
import matplotlib.patches as patches
import albumentations as A
import cv2
import numpy as np
import os
import pandas as pd
import torch.nn.functional as F
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from collections import Counter
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import torch
import torch.nn as nn
import warnings
from tensorboardX import SummaryWriter
from torchsummary import summary

warnings.filterwarnings("ignore")

In [None]:
# Let cuDNN autotune convolution algorithms (input size is fixed at IMAGE_SIZE).
torch.backends.cudnn.benchmark = True

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Paths --------------------------------------------------------------
DATASET = "/content/drive/MyDrive"
IMG_DIR = f"{DATASET}/images1/"
LABEL_DIR = f"{DATASET}/labels1/"
CHECKPOINT_FILE = "checkpoint.pth.tar"

# ---- Training hyper-parameters -----------------------------------------
NUM_EPOCHS = 200
BATCH_SIZE = 8
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
PIN_MEMORY = True
LOAD_MODEL = True   # resume from Trainer.weight_path when True
SAVE_MODEL = True   # write a checkpoint after each epoch when True

# ---- Geometry -----------------------------------------------------------
IMAGE_SIZE = 416
NUM_CLASSES = 20
# Train-time padding factor: images are resized/padded to IMAGE_SIZE * scale
# and then randomly cropped back to IMAGE_SIZE.
scale = 1.1

# Grid sizes of the three output scales (strides 32, 16, 8).
S = [IMAGE_SIZE // stride for stride in (32, 16, 8)]

# Anchor (w, h) pairs relative to the image size, one row of three per scale,
# ordered to match S (coarsest grid / largest anchors first).
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]

# ---- Evaluation thresholds ---------------------------------------------
CONF_THRESHOLD = 0.05   # minimum objectness for a box to be considered
NMS_IOU_THRESH = 0.45   # IoU used for non-max suppression
MAP_IOU_THRESH = 0.5    # IoU for a detection to count as a true positive

PASCAL_CLASSES = (
    "aeroplane bicycle bird boat bottle bus car cat chair cow "
    "diningtable dog horse motorbike person pottedplant sheep sofa train tvmonitor"
).split()

"""
UTILS
"""
# Training augmentation. The image's longest side is resized to IMAGE_SIZE * scale
# and padded to a square of that size. A random IMAGE_SIZE x IMAGE_SIZE crop is then
# taken, followed by colour and geometric augmentation.
# Boxes use YOLO format (relative x_center, y_center, w, h). The class label is
# carried as the extra trailing element of each box (see YOLODataset.__getitem__),
# so label_fields is empty.
train_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
        A.PadIfNeeded(
            min_height=int(IMAGE_SIZE * scale),
            min_width=int(IMAGE_SIZE * scale),
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
        A.ColorJitter(brightness=0.6, contrast=0.6,
                      saturation=0.6, hue=0.6, p=0.4),
        # OneOf (p=1.0) picks one of these geometric transforms, weighted by
        # their p values, and applies it.
        A.OneOf(
            [
                A.ShiftScaleRotate(
                    rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
                ),
                # NOTE(review): IAA* transforms wrap imgaug and are deprecated in
                # albumentations 1.x; confirm imgaug is available in this runtime.
                A.IAAAffine(shear=15, p=0.5, mode="constant"),
            ],
            p=1.0,
        ),
        A.HorizontalFlip(p=0.5),
        A.Blur(p=0.1),
        A.CLAHE(p=0.1),
        A.Posterize(p=0.1),
        A.ToGray(p=0.1),
        A.ChannelShuffle(p=0.05),
        # mean 0 / std 1 with max_pixel_value 255 only rescales pixels to [0, 1].
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    # Boxes keeping less than 40% of their area after augmentation are dropped.
    bbox_params=A.BboxParams(
        format="yolo", min_visibility=0.4, label_fields=[],),
)
# Evaluation transform: deterministic resize of the longest side to IMAGE_SIZE,
# pad to a square, then scale pixels to [0, 1]. No random augmentation.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(
        format="yolo", min_visibility=0.4, label_fields=[]),
)



def iou_width_height(boxes1, boxes2):
    """IoU of boxes compared by width/height only, as if they shared a centre.

    Args:
        boxes1, boxes2: tensors whose last axis starts with (w, h); they are
            broadcast against each other.

    Returns:
        Tensor of IoU values with the broadcast shape (last axis removed).
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]

    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    combined = w1 * h1 + w2 * h2 - overlap
    return overlap / combined


def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Element-wise IoU between two sets of boxes.

    Args:
        boxes_preds, boxes_labels: tensors whose last axis holds 4 box values.
            They are broadcast against each other.
        box_format: "midpoint" for (x_center, y_center, w, h), or "corners"
            for (x1, y1, x2, y2).

    Returns:
        Tensor of IoU values with a trailing axis of size 1.

    Raises:
        ValueError: for an unknown ``box_format``. Previously only "midpoint"
            was handled, and any other value (including "corners", the default
            of ``non_max_suppression``) failed with UnboundLocalError.
    """
    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        raise ValueError(f"Unsupported box_format: {box_format!r}")

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes non-overlapping boxes contribute zero intersection.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    box2_area = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()

    # Small epsilon guards against division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)


def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """Greedy per-class non-maximum suppression.

    Args:
        bboxes: list of boxes ``[class, score, *coords]``.
        iou_threshold: a lower-scored box of the same class is dropped when its
            IoU with a kept box is >= this value.
        threshold: boxes with score <= this value are discarded up front.
        box_format: coordinate format forwarded to ``intersection_over_union``.

    Returns:
        The surviving boxes, in descending score order.
    """
    assert type(bboxes) is list

    candidates = sorted(
        (b for b in bboxes if b[1] > threshold),
        key=lambda b: b[1],
        reverse=True,
    )
    kept = []

    while candidates:
        best = candidates.pop(0)
        kept.append(best)

        survivors = []
        for other in candidates:
            # Boxes of a different class never suppress each other.
            if not other[0] == best[0]:
                survivors.append(other)
                continue
            overlap = intersection_over_union(
                torch.tensor(best[2:]),
                torch.tensor(other[2:]),
                box_format=box_format,
            )
            if overlap < iou_threshold:
                survivors.append(other)
        candidates = survivors

    return kept


def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """Mean of per-class average precision (area under the precision/recall curve).

    Args:
        pred_boxes: list of ``[image_idx, class, score, x, y, w, h]`` detections.
        true_boxes: list of ground truths in the same layout.
        iou_threshold: a detection counts as a true positive when its best IoU
            with an unmatched ground truth of the same image and class exceeds this.
        box_format: coordinate format forwarded to ``intersection_over_union``.
        num_classes: classes ``0 .. num_classes - 1`` are evaluated; classes with
            no ground truth are skipped.

    Returns:
        0-dim tensor with the mAP. ``tensor(0.)`` when no evaluated class has any
        ground truth. Previously this case raised ZeroDivisionError.
    """
    average_precisions = []
    epsilon = 1e-6

    for c in range(num_classes):
        detections = [d for d in pred_boxes if d[1] == c]
        ground_truths = [gt for gt in true_boxes if gt[1] == c]
        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            continue

        # Group ground truths per image once, instead of rescanning the whole
        # list for every detection.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        # matched[img][k] == 1 once ground truth k of that image has been claimed.
        matched = {img: torch.zeros(len(gts)) for img, gts in gts_by_image.items()}

        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            ground_truth_img = gts_by_image.get(detection[0], [])

            best_iou = 0
            best_gt_idx = -1
            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if (
                best_gt_idx >= 0
                and best_iou > iou_threshold
                and matched[detection[0]][best_gt_idx] == 0
            ):
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                # No overlap, too little overlap, or the ground truth was
                # already claimed by a higher-scored detection.
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Prepend the (recall=0, precision=1) point so the curve starts at the y-axis.
        precisions = torch.cat((torch.tensor([1]), precisions))
        recalls = torch.cat((torch.tensor([0]), recalls))
        average_precisions.append(torch.trapz(precisions, recalls))

    if not average_precisions:
        return torch.tensor(0.0)
    return sum(average_precisions) / len(average_precisions)


def get_evaluation_bboxes(
    loader,
    model,
    iou_threshold,
    anchors,
    threshold,
    box_format="midpoint",
    device="cuda",
):
    """Collect NMS-filtered predictions and ground-truth boxes over a loader.

    The model is put in eval mode for the pass and switched back to train mode
    before returning. Every returned box is prefixed with a running image index,
    so predictions and ground truths of the same image can be paired by
    ``mean_average_precision``. Both lists use the layout
    ``[image_idx, class, score, x, y, w, h]``.

    Args:
        loader: yields ``(images, labels)``. Ground truth is read from the last
            scale only (``labels[2]``).
        model: returns one raw prediction tensor per scale.
        iou_threshold: IoU used by ``non_max_suppression``.
        anchors: per-scale anchor (w, h) lists, relative to the image size.
        threshold: minimum score for predicted and ground-truth boxes to be kept.
        box_format: forwarded to ``non_max_suppression``.
        device: device that images and anchors are moved to.

    Returns:
        ``(all_pred_boxes, all_true_boxes)``.
    """
    model.eval()
    image_counter = 0
    all_pred_boxes, all_true_boxes = [], []

    for images, labels in tqdm(loader):
        images = images.to(device)
        with torch.no_grad():
            predictions = model(images)

        n_images = images.shape[0]
        per_image_boxes = [[] for _ in range(n_images)]
        for scale_idx in range(3):
            grid = predictions[scale_idx].shape[2]
            # Anchors are image-relative; multiply by the grid size for cell units.
            scale_anchors = torch.tensor([*anchors[scale_idx]]).to(device) * grid
            scale_boxes = cells_to_bboxes(
                predictions[scale_idx], scale_anchors, S=grid, is_preds=True
            )
            for image_boxes, new_boxes in zip(per_image_boxes, scale_boxes):
                image_boxes += new_boxes

        # Uses the anchors/grid of the last scale processed above (scale 2).
        true_bboxes = cells_to_bboxes(
            labels[2], scale_anchors, S=grid, is_preds=False
        )

        for image_idx in range(n_images):
            kept = non_max_suppression(
                per_image_boxes[image_idx],
                iou_threshold=iou_threshold,
                threshold=threshold,
                box_format=box_format,
            )
            all_pred_boxes.extend([image_counter] + b for b in kept)
            all_true_boxes.extend(
                [image_counter] + b for b in true_bboxes[image_idx] if b[1] > threshold
            )
            image_counter += 1

    model.train()
    return all_pred_boxes, all_true_boxes


def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """Convert per-cell predictions/targets of one scale into image-relative boxes.

    Args:
        predictions: tensor of shape ``(batch, num_anchors, S, S, D)``.
            For raw model output (``is_preds=True``) the last axis is
            ``[obj_logit, tx, ty, tw, th, class_logits...]``. For targets
            (``is_preds=False``) it is ``[obj, x_cell, y_cell, w_cell, h_cell, class]``.
        anchors: ``(num_anchors, 2)`` anchor sizes in grid-cell units. Only used
            when ``is_preds`` is True.
        S: grid size of this scale.
        is_preds: whether ``predictions`` is raw model output.

    Returns:
        Nested Python list of shape ``(batch, num_anchors * S * S, 6)`` with rows
        ``[class, score, x, y, w, h]`` relative to the image.

    ``predictions`` is not modified. The previous version wrote the sigmoid/exp
    results back into the caller's tensor through a slice view.
    """
    batch_size = predictions.shape[0]
    num_anchors = len(anchors)
    box_predictions = predictions[..., 1:5]

    if is_preds:
        anchors = anchors.reshape(1, num_anchors, 1, 1, 2)
        xy = torch.sigmoid(box_predictions[..., 0:2])
        wh = torch.exp(box_predictions[..., 2:4]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        xy = box_predictions[..., 0:2]
        wh = box_predictions[..., 2:4]
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # cell_indices[..., i, j, 0] == j (column index); permuting the grid axes gives i.
    cell_indices = (
        torch.arange(S, device=predictions.device)
        .repeat(batch_size, num_anchors, S, 1)
        .unsqueeze(-1)
    )
    x = 1 / S * (xy[..., 0:1] + cell_indices)
    y = 1 / S * (xy[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * wh
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
        batch_size, num_anchors * S * S, 6
    )
    return converted_bboxes.tolist()


def check_class_accuracy(model, loader, threshold):
    """Print class, no-object and object accuracy of ``model`` over ``loader``.

    - Class accuracy: argmax of the class logits vs. the target class, over
      cells whose target objectness is 1.
    - Obj / no-obj accuracy: ``sigmoid(obj_logit) > threshold`` compared with the
      target objectness, over target-1 and target-0 cells respectively. Cells
      marked -1 (ignored) are counted in neither.

    The model is switched to eval mode and back to train mode afterwards.
    Percentages are printed with two decimals. The old ``:2f`` spec meant
    "width 2" and printed six decimals.
    """
    model.eval()
    tot_class_preds, correct_class = 0, 0
    tot_noobj, correct_noobj = 0, 0
    tot_obj, correct_obj = 0, 0

    with torch.no_grad():
        for x, y in tqdm(loader):
            out = model(x.to(DEVICE))

            # Use a local name instead of writing back into the batch's y list.
            for scale_out, scale_target in zip(out, y):
                scale_target = scale_target.to(DEVICE)
                obj = scale_target[..., 0] == 1
                noobj = scale_target[..., 0] == 0

                correct_class += torch.sum(
                    torch.argmax(scale_out[..., 5:][obj], dim=-1)
                    == scale_target[..., 5][obj]
                )
                tot_class_preds += torch.sum(obj)

                obj_preds = torch.sigmoid(scale_out[..., 0]) > threshold
                correct_obj += torch.sum(obj_preds[obj] == scale_target[..., 0][obj])
                tot_obj += torch.sum(obj)
                correct_noobj += torch.sum(
                    obj_preds[noobj] == scale_target[..., 0][noobj]
                )
                tot_noobj += torch.sum(noobj)

    print(
        f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:.2f}%")
    print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:.2f}%")
    print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:.2f}%")

    model.train()
    model.train()

class YOLODataset(Dataset):
    """Dataset yielding an image and its multi-scale YOLO training targets.

    Each CSV row gives an image filename (column 0, under ``img_dir``) and a label
    filename (column 1, under ``label_dir``). A label file holds one box per line
    as space-separated ``class x y w h``, with coordinates relative to the image.

    ``__getitem__`` returns ``(image, targets)``. ``targets[k]`` has shape
    ``(anchors_per_scale, S[k], S[k], 6)``, and its last axis is
    ``[objectness, x_cell, y_cell, w_cell, h_cell, class]``. Objectness is:
      * 1 for the best free anchor per scale assigned to a box;
      * -1 for other free anchors whose IoU with the box exceeds
        ``ignore_iou_thresh`` (YoloLoss counts them as neither obj nor no-obj);
      * 0 otherwise.
    """

    def __init__(
        self,
        csv_file,
        img_dir,
        label_dir,
        anchors,
        image_size=416,
        S=[13, 26, 52],
        C=20,
        transform=None,
    ):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.image_size = image_size  # stored; not used by this class
        self.transform = transform
        self.S = S
        # Flatten the per-scale anchor lists into one (num_anchors, 2) tensor.
        self.anchors = torch.tensor(
            [anchor for scale_anchors in anchors for anchor in scale_anchors]
        )
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // len(anchors)
        self.C = C
        self.ignore_iou_thresh = 0.5

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        label_path = os.path.join(
            self.label_dir, self.annotations.iloc[index, 1])
        # Rows are "class x y w h". Rolling by 4 moves the class to the end,
        # giving [x, y, w, h, class], the layout albumentations' "yolo" format
        # accepts with the label as a trailing field.
        bboxes = np.roll(np.loadtxt(fname=label_path,
                         delimiter=" ", ndmin=2), 4, axis=1).tolist()
        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = np.array(Image.open(img_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=image, bboxes=bboxes)
            image = augmentations["image"]
            bboxes = augmentations["bboxes"]

        targets = [
            torch.zeros((self.num_anchors_per_scale, grid, grid, 6))
            for grid in self.S
        ]
        for box in bboxes:
            # Rank all anchors (across every scale) by shape similarity to the box.
            iou_anchors = iou_width_height(
                torch.tensor(box[2:4]), self.anchors)
            anchor_indices = iou_anchors.argsort(descending=True, dim=0)
            x, y, width, height, class_label = box
            has_anchor = [False] * len(self.S)
            for anchor_idx in anchor_indices:
                scale_idx = anchor_idx // self.num_anchors_per_scale
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale
                grid = self.S[scale_idx]
                # Clamp so a centre lying exactly on the right/bottom border
                # (x or y == 1.0) maps to the last cell instead of indexing past it.
                i = min(int(grid * y), grid - 1)
                j = min(int(grid * x), grid - 1)
                anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
                if not anchor_taken and not has_anchor[scale_idx]:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = 1
                    # Centre offset inside the cell, and size in cell units.
                    x_cell, y_cell = grid * x - j, grid * y - i
                    width_cell, height_cell = width * grid, height * grid
                    targets[scale_idx][anchor_on_scale, i, j, 1:5] = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    )
                    targets[scale_idx][anchor_on_scale,
                                       i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True

                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    # A good but unassigned anchor: mark it -1 so the loss skips it.
                    targets[scale_idx][anchor_on_scale, i, j, 0] = -1

        return image, tuple(targets)




def get_loaders(train_csv_path, test_csv_path):
    """Build the train and test DataLoaders.

    Both datasets read images/labels from IMG_DIR / LABEL_DIR with the global
    ANCHORS and grid sizes for strides 32/16/8. The train loader shuffles;
    the test loader does not. Neither drops the last partial batch.

    Returns:
        ``(train_loader, test_loader)``.
    """

    def make_dataset(csv_path, transform):
        # A fresh grid-size list per dataset.
        grid_sizes = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
        return YOLODataset(
            csv_path,
            transform=transform,
            S=grid_sizes,
            img_dir=IMG_DIR,
            label_dir=LABEL_DIR,
            anchors=ANCHORS,
        )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset=dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY,
            shuffle=shuffle,
            drop_last=False,
        )

    train_dataset = make_dataset(train_csv_path, train_transforms)
    test_dataset = make_dataset(test_csv_path, test_transforms)
    return make_loader(train_dataset, True), make_loader(test_dataset, False)



class YoloLoss(nn.Module):
    """YOLOv3-style loss for one output scale.

    Components (weighted by the ``lambda_*`` attributes):
      * no-object: BCE-with-logits on objectness for target-0 cells;
      * object: MSE between sigmoid(objectness) and IoU(pred box, target box)
        for target-1 cells;
      * box: MSE on (sigmoid(tx), sigmoid(ty), tw, th) vs.
        (x_cell, y_cell, log(w / anchor_w), log(h / anchor_h));
      * class: cross-entropy on the class logits.
    Cells whose target objectness is -1 are ignored by all terms.

    Neither ``predictions`` nor ``target`` is modified. The previous version
    overwrote parts of both in place. A scale with no target-1 cells contributes
    only the no-object term. Before, the empty MSE/CE terms produced NaN.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, target, anchors, writer, length, epoch, i):
        """Return the weighted loss. Every 2000th batch (``i % 2000 == 0``) the
        components are also logged to ``writer`` at step ``epoch``.

        Args:
            predictions: raw model output for this scale, ``(B, 3, S, S, 5 + C)``.
            target: matching targets, ``(B, 3, S, S, 6)``.
            anchors: ``(3, 2)`` anchors in grid-cell units.
            writer: TensorBoard writer; only used when logging.
            length: number of batches per epoch (currently unused).
            epoch: global step used for logging.
            i: batch index within the epoch.
        """
        obj = target[..., 0] == 1
        noobj = target[..., 0] == 0

        no_object_loss = self.bce(
            (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),
        )

        if obj.any():
            anchors = anchors.reshape(1, 3, 1, 1, 2)
            pred_xy = self.sigmoid(predictions[..., 1:3])

            # Objectness target is the (detached) IoU of the decoded prediction.
            box_preds = torch.cat(
                [pred_xy, torch.exp(predictions[..., 3:5]) * anchors], dim=-1
            )
            ious = intersection_over_union(
                box_preds[obj], target[..., 1:5][obj]).detach()
            object_loss = self.mse(self.sigmoid(
                predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])

            # Box regression in (sigmoid offset, log-space size) coordinates.
            # New tensors are built instead of overwriting the inputs.
            pred_boxes = torch.cat([pred_xy, predictions[..., 3:5]], dim=-1)
            target_boxes = torch.cat(
                [target[..., 1:3], torch.log(1e-16 + target[..., 3:5] / anchors)],
                dim=-1,
            )
            box_loss = self.mse(pred_boxes[obj], target_boxes[obj])

            class_loss = self.entropy(
                (predictions[..., 5:][obj]), (target[..., 5][obj].long()),
            )
        else:
            # No assigned objects on this scale: the MSE/CE over empty selections
            # would be NaN, so these terms contribute zero instead.
            zero = predictions.new_zeros(())
            object_loss = box_loss = class_loss = zero

        total = (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )

        if i % 2000 == 0:
            writer.add_scalar(
                "Regression_loss",
                self.lambda_box * box_loss,
                epoch,
            )
            writer.add_scalar(
                "Confidence_loss",
                self.lambda_obj * object_loss + self.lambda_noobj * no_object_loss,
                epoch,
            )
            writer.add_scalar(
                "Classification_loss",
                self.lambda_class * class_loss,
                epoch,
            )
            writer.add_scalar(
                "Total_loss",
                total,
                epoch,
            )

        return total


class Conv(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d and a Mish activation.

    When ``bn_act`` is True the convolution has no bias, because BatchNorm's
    shift makes it redundant. Extra keyword arguments go to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        self.conv = nn.Conv2d(in_channels, out_channels,
                              bias=not bn_act, **kwargs)
        # Built even when unused: removing it would change the state_dict keys
        # and break load_state_dict on existing checkpoints.
        self.bn = nn.BatchNorm2d(out_channels)
        # Historically named "leaky"; it is a Mish activation.
        self.leaky = Mish()

    def forward(self, x):
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))


class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``. Has no parameters."""

    def forward(self, x):
        return torch.tanh(F.softplus(x)) * x


class ConvBlock(nn.Module):
    """ConvNeXt-style residual block using BatchNorm and Mish.

    Layout: 7x7 depthwise conv -> BatchNorm -> pointwise expand to 4*dim (Linear)
    -> Mish -> pointwise project back to dim (Linear) -> optional per-channel
    layer scale ``gamma`` -> residual add with the input.

    ``drop_path`` is accepted but ignored: ``self.drop_path`` is always Identity.
    """

    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        hidden = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        # The 1x1 convolutions are implemented as Linear layers on channels-last data.
        self.pwconv1 = nn.Linear(dim, hidden)
        self.act = Mish()
        self.pwconv2 = nn.Linear(hidden, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None
        self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        h = self.norm(self.dwconv(x)).permute(0, 2, 3, 1)  # NCHW -> NHWC
        h = self.pwconv2(self.act(self.pwconv1(h)))
        if self.gamma is not None:
            h = h * self.gamma
        h = h.permute(0, 3, 1, 2)  # NHWC -> NCHW
        return shortcut + self.drop_path(h)


class Scale(nn.Module):
    """Prediction head for one output scale.

    A 3x3 Conv (doubling channels) is followed by a bias-only 1x1 conv that
    produces ``3 * (num_classes + 5)`` channels. The result is reshaped to
    ``(B, 3, H, W, num_classes + 5)``: three anchors per cell.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        hidden = 2 * in_channels
        per_anchor = num_classes + 5
        self.pred = nn.Sequential(
            Conv(in_channels, hidden, kernel_size=3, padding=1),
            Conv(hidden, per_anchor * 3, bn_act=False, kernel_size=1),
        )
        self.num_classes = num_classes

    def forward(self, x):
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        raw = self.pred(x)
        # (B, 3*(C+5), H, W) -> (B, 3, C+5, H, W) -> (B, 3, H, W, C+5)
        grouped = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        return grouped.permute(0, 1, 3, 4, 2)


class SpatialPyramidPooling(nn.Module):
    """SPP block.

    A 1x1 Conv halves the channels of the last feature map. Stride-1 max pools
    of each size in ``pool_sizes`` (padding k//2, so spatial size is preserved)
    are then concatenated with the un-pooled features along channels. Output
    channels are ``(len(pool_sizes) + 1) * feature_channels[-1] // 2``.
    """

    def __init__(self, feature_channels, pool_sizes=[5, 9, 13]):
        super().__init__()
        in_ch = feature_channels[-1]
        self.head_conv = nn.Sequential(
            Conv(in_ch, in_ch // 2, kernel_size=1),
        )
        self.maxpools = nn.ModuleList(
            nn.MaxPool2d(k, 1, k // 2) for k in pool_sizes
        )
        self.__initialize_weights()

    def forward(self, x):
        x = self.head_conv(x)
        pyramid = [x]
        pyramid.extend(pool(x) for pool in self.maxpools)
        return torch.cat(pyramid, dim=1)

    def __initialize_weights(self):
        # Re-initialise every Conv2d with N(0, 0.01) (zero bias) and reset BatchNorm to identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(0, 0.01)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()


class ResBlock(nn.Module):
    """Sequential stack of ``num_repeats`` ConvBlocks.

    ``use_residual`` is stored but not consulted. Every ConvBlock already adds
    its own skip connection.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        # Each ConvBlock sits in a one-element nn.Sequential. Keep the wrapper so
        # parameter names (layers.<k>.0.*) match saved checkpoints.
        self.layers = nn.ModuleList(
            [nn.Sequential(ConvBlock(channels)) for _ in range(num_repeats)]
        )
        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out


class MyModel(nn.Module):
    """YOLOv3-style detector: ConvNeXt-like backbone, SPP neck and a 3-scale head.

    The backbone downsamples by 32 in total (one stride-2 conv in the stem, plus
    four stride-2 downsample convs). Feature maps at strides 8 and 16 are kept as
    skip routes. ``forward`` returns three tensors from the coarsest to the finest
    scale, each shaped ``(B, 3, H, W, num_classes + 5)`` (see Scale).

    Attribute names are part of the state_dict used by saved checkpoints. Do not
    rename them.
    """

    def __init__(self, in_channels=3, num_classes=20):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        # Channel width of each backbone stage.
        self.depths = [64, 128, 256, 512, 1024]
        # ConvBlocks per stage. Stages with exactly 5 repeats are also used as
        # skip routes in forward().
        self.num_repeats = [1, 2, 5, 5, 4]
        # Channel widths used by the two head upsampling convs (512->256, 256->128).
        self.up = [512, 256, 128]
        self.feature_channels = [256, 512, 1024]

        # Stem: stride-1 conv to 32 channels, then stride-2 conv to 64.
        start = nn.Sequential(
            Conv(in_channels, 32, kernel_size=3, stride=1, padding=1),
            Conv(32, 64, kernel_size=3, stride=2, padding=1),
        )

        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(start)
        # Four stride-2 convs between stages: 64->128->256->512->1024 channels.
        for i in range(4):
            downsample = nn.Sequential(
                Conv(self.depths[i], self.depths[i+1],
                     kernel_size=3, stride=2, padding=1),
            )
            self.downsample_layers.append(downsample)

        self.stages = nn.ModuleList()
        for i in range(5):
            stage = nn.Sequential(
                ResBlock(channels=self.depths[i], use_residual=True,
                         num_repeats=self.num_repeats[i])
            )

            self.stages.append(stage)

        # Prediction heads, coarsest first (stride 32, 16, 8).
        scale1 = Scale(512, num_classes=self.num_classes)
        scale2 = Scale(256, num_classes=self.num_classes)
        scale3 = Scale(128, num_classes=self.num_classes)
        self.scales = nn.ModuleList()
        self.scales.append(scale1)
        self.scales.append(scale2)
        self.scales.append(scale3)

        # Head trunks feeding each Scale. route2/route3 input widths are the
        # upsampled features concatenated with a backbone route:
        # 256 + 512 = 768 and 128 + 256 = 384.
        route1 = nn.Sequential(
            Conv(1024, 1024, kernel_size=3, stride=1, padding=1),
            ResBlock(1024, use_residual=False),
            Conv(1024, 512, kernel_size=1),
        )

        route2 = nn.Sequential(
            Conv(768, 256, kernel_size=1, stride=1, padding=0),
            Conv(256, 512, kernel_size=3, stride=1, padding=1),
            ResBlock(512, use_residual=False),
            Conv(512, 256, kernel_size=1),
        )

        route3 = nn.Sequential(
            Conv(384, 128, kernel_size=1, stride=1, padding=0),
            Conv(128, 256, kernel_size=3, stride=1, padding=1),
            ResBlock(256, use_residual=False),
            Conv(256, 128, kernel_size=1),
        )
        self.routes = nn.ModuleList()
        self.routes.append(route1)
        self.routes.append(route2)
        self.routes.append(route3)

        # 1x1 channel reduction followed by 2x nearest upsampling.
        self.upsample = nn.ModuleList()
        for i in range(2):
            upsample = nn.Sequential(
                Conv(self.up[i], self.up[i+1],
                     kernel_size=1, stride=1, padding=0),
                nn.Upsample(scale_factor=2)
            )
            self.upsample.append(upsample)

        # SPP maps 1024 -> 4 * 512 = 2048 channels; downstream_conv reduces back to 1024.
        self.spp = SpatialPyramidPooling(self.feature_channels)
        self.downstream_conv = nn.Sequential(
            Conv(2048, 1024, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """¨¨¨¨¨¨¨¨¨¨BACKBONE¨¨¨¨¨¨¨¨¨¨¨"""
        # Returns [coarse, medium, fine] prediction tensors (see class docstring).
        routes = []
        for i in range(5):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            # NOTE(review): routes are selected by repeat count (stages 2 and 3,
            # strides 8 and 16). Editing num_repeats silently changes this.
            if self.num_repeats[i] == 5:
                routes.append(x)
        """¨¨¨¨¨¨¨¨¨¨BACKBONE¨¨¨¨¨¨¨¨¨¨¨"""

        """-----------NECK---------------"""
        x = self.spp(x)
        x = self.downstream_conv(x)
        """-----------NECK--------------"""

        """-----------HEAD-------------"""

        outs = []
        for i in range(2):
            x = self.routes[i](x)
            outs.append(self.scales[i](x))
            x = self.upsample[i](x)
            # routes[-1] is the stride-16 map and routes[-2] the stride-8 map,
            # matching the resolution after each 2x upsample.
            x = torch.cat([x, routes[-i-1]], dim=1)

        x = self.routes[-1](x)
        outs.append(self.scales[-1](x))
        """-----------HEAD-------------"""

        return outs






def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors, writer, epoch):
    """Train ``model`` for one epoch with automatic mixed precision.

    The loss is the sum of ``loss_fn`` over the three output scales. A running
    mean of the batch losses is shown in the tqdm progress bar.

    Args:
        train_loader: yields ``(images, targets)`` with one target per scale.
        model, optimizer, loss_fn, scaler: training objects (``scaler`` is a GradScaler).
        scaled_anchors: per-scale anchors in grid-cell units.
        writer: TensorBoard writer, forwarded to ``loss_fn``.
        epoch: current epoch, forwarded to ``loss_fn`` for logging.

    Returns:
        float: mean batch loss for the epoch (0.0 if the loader is empty).
    """
    loop = tqdm(train_loader, leave=True)
    length = len(train_loader)
    loss_sum = 0.0
    mean_loss = 0.0
    for batch_idx, (x, y) in enumerate(loop):
        x = x.to(DEVICE)
        y0, y1, y2 = (
            y[0].to(DEVICE),
            y[1].to(DEVICE),
            y[2].to(DEVICE),
        )

        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0], writer, length, epoch, batch_idx)
                + loss_fn(out[1], y1, scaled_anchors[1], writer, length, epoch, batch_idx)
                + loss_fn(out[2], y2, scaled_anchors[2], writer, length, epoch, batch_idx)
            )

        # Keep a running sum instead of re-summing a growing list each step,
        # which was quadratic in the number of batches.
        loss_sum += loss.item()
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        mean_loss = loss_sum / (batch_idx + 1)
        loop.set_postfix(loss=mean_loss)

    return mean_loss


class Trainer(object):
    """Owns the model, optimizer, AMP scaler and data loaders, and runs training.

    On epochs where ``epoch % 5 == 0`` the model is evaluated on the test loader:
    class/objectness accuracy and mAP are computed, and mAP is logged to
    TensorBoard as "MAP". When SAVE_MODEL is set, a checkpoint is written to
    ``self.save`` after every epoch. Evaluation happens first, so the stored
    "best_mAP" is up to date. Previously it was always 0.
    """

    def __init__(self, weight_path, train_loader, test_loader):
        self.start_epoch = 0
        self.best_mAP = 0.0
        self.weight_path = weight_path
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model = MyModel().to(DEVICE)
        self.optimizer = optim.AdamW(
            self.model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
        )
        self.loss_fn = YoloLoss()
        self.scaler = torch.cuda.amp.GradScaler()
        # ANCHORS are image-relative. Multiply each scale's anchors by that
        # scale's grid size so they are in grid-cell units, like the targets.
        self.scaled_anchors = (
            torch.tensor(ANCHORS)
            * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(DEVICE)
        self.save = '/content/drive/MyDrive/YOLOv3_version2.pth.tar'

    def load_checkpoint(self, checkpoint_file, lr):
        """Restore model/optimizer state, resume epoch and best mAP, then set ``lr``.

        NOTE: torch.load unpickles the file. Only load trusted checkpoints.
        """
        print("=> Loading checkpoint")
        print(checkpoint_file)
        checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.start_epoch = checkpoint["epoch"] + 1
        self.best_mAP = float(checkpoint.get("best_mAP", self.best_mAP))

        # Override the learning rate stored in the optimizer state.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def save_checkpoint(self, filename, epoch, mAP):
        """Save model/optimizer state, epoch and ``mAP`` (stored as "best_mAP")."""
        print("=> Saving checkpoint")
        checkpoint = {
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": epoch,
            "best_mAP": mAP,
        }
        torch.save(checkpoint, filename)

    def _evaluate(self, writer, epoch):
        """Print test-set accuracies, compute mAP, log it as "MAP", and return it as a float."""
        check_class_accuracy(self.model, self.test_loader,
                             threshold=CONF_THRESHOLD)
        pred_boxes, true_boxes = get_evaluation_bboxes(
            self.test_loader,
            self.model,
            iou_threshold=NMS_IOU_THRESH,
            anchors=ANCHORS,
            threshold=CONF_THRESHOLD,
            # The function's own default is "cuda". Use the configured device.
            device=DEVICE,
        )
        mapval = float(
            mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=NUM_CLASSES,
            )
        )
        print(f"MAP: {mapval}")
        self.model.train()
        writer.add_scalar("MAP", mapval, epoch)
        return mapval

    def train(self, writer):
        """Train from ``self.start_epoch`` to NUM_EPOCHS, logging to ``writer``."""
        if LOAD_MODEL:
            if os.path.isfile(self.weight_path):
                self.load_checkpoint(self.weight_path, LEARNING_RATE)
            else:
                print(f"=> No checkpoint found at {self.weight_path}; training from scratch")

        for epoch in range(self.start_epoch, NUM_EPOCHS):
            print(epoch)
            train_fn(
                self.train_loader,
                self.model,
                self.optimizer,
                self.loss_fn,
                self.scaler,
                self.scaled_anchors,
                writer,
                epoch,
            )

            # Evaluate before saving so the checkpoint records an up-to-date mAP.
            if epoch % 5 == 0:
                mAP = self._evaluate(writer, epoch)
                self.best_mAP = max(self.best_mAP, mAP)

            if SAVE_MODEL:
                self.save_checkpoint(filename=self.save, epoch=epoch, mAP=self.best_mAP)
          
def main():
    """Build the data loaders and run training, logging to TensorBoard.

    The SummaryWriter is always closed, even if training raises, so buffered
    events are flushed to disk.
    """
    log_path = 'YOLO/COLAB_train/tensorboard'
    writer = SummaryWriter(logdir=log_path + "/MyModel")
    try:
        train_loader, test_loader = get_loaders(
            train_csv_path="/content/drive/MyDrive/train.csv",
            test_csv_path="/content/drive/MyDrive/test.csv",
        )
        weight_path = '/content/drive/MyDrive/YOLOv3_version1.pth.tar'

        Trainer(weight_path, train_loader, test_loader).train(writer)
    finally:
        writer.close()





In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
%load_ext tensorboard%
tensorboard --logdir=/content/drive/MyDrive/puvodni_znova

In [None]:
main()