In [2]:
####

In [3]:
import torch
from torch import nn
import torchvision.transforms.functional as F
import torch.nn.functional as f

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision.utils import make_grid
from torchvision import transforms
from tqdm.notebook import tqdm

In [5]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
def Reverse(lst):
    """Return a new list holding the items of ``lst`` in reverse order; ``lst`` is left untouched."""
    return list(reversed(lst))

In [7]:
def show_tensor_images(image_tensor):
    """Plot a batch of images as a grid with 5 images per row.

    ``image_tensor`` is indexed as (N, C, H, W); it is detached and copied to the CPU
    before plotting, so the caller's tensor is not affected.
    """
    n_images = image_tensor.shape[0]
    chw = (image_tensor.shape[1], image_tensor.shape[2], image_tensor.shape[3])
    flat = image_tensor.detach().cpu().view(-1, *chw)
    grid = make_grid(flat[:n_images], nrow=5)
    # make_grid returns (C, H, W); imshow wants (H, W, C), squeezed for single-channel grids.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [8]:

def iou_width_height(boxes1, boxes2):
    """IoU of boxes compared by width and height only, as if they shared a common corner.

    ``boxes1[..., 0]`` / ``boxes1[..., 1]`` are widths / heights (same for ``boxes2``);
    the two arguments are broadcast against each other.
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]
    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    return overlap / (w1 * h1 + w2 * h2 - overlap)

In [9]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Element-wise IoU between two broadcastable sets of boxes.

    Args:
        boxes_preds, boxes_labels: tensors whose last dimension holds 4 box values.
        box_format: ``"midpoint"`` for ``(x_center, y_center, width, height)`` or
            ``"corners"`` for ``(x1, y1, x2, y2)``.

    Returns:
        Tensor of shape ``(..., 1)``; a 1e-6 epsilon keeps the division finite.

    Raises:
        ValueError: if ``box_format`` is not one of the two supported formats
            (previously this surfaced as an UnboundLocalError).
    """
    if box_format not in ("midpoint", "corners"):
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    def _corners(boxes):
        # Slices (0:1 etc.) keep a trailing singleton dim so the result is (..., 1).
        if box_format == "midpoint":
            half_w = boxes[..., 2:3] / 2
            half_h = boxes[..., 3:4] / 2
            return (
                boxes[..., 0:1] - half_w,
                boxes[..., 1:2] - half_h,
                boxes[..., 0:1] + half_w,
                boxes[..., 1:2] + half_h,
            )
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]

    box1_x1, box1_y1, box1_x2, box1_y2 = _corners(boxes_preds)
    box2_x1, box2_y1, box2_x2, box2_y2 = _corners(boxes_labels)

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes disjoint boxes contribute zero intersection.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [10]:
class Conv(nn.Module):
    """``nn.Conv2d`` optionally followed by ``BatchNorm2d`` and ``ReLU6``.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d`` (defaults keep H and W).
        use_norm: append ``BatchNorm2d(out_channels)`` (stored as ``self.norm``).
        use_activation: append ``ReLU6`` (stored as ``self.relu``).
        use_grps: use ``groups=out_channels`` (a depthwise conv when
            ``in_channels == out_channels``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True,
                 use_grps=False):
        super().__init__()
        self.use_norm = use_norm
        self.use_activation = use_activation
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size,
                               stride,
                               padding,
                               groups=(out_channels if use_grps else 1))
        if use_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        if use_activation:
            self.relu = nn.ReLU6()

    def forward(self, x):
        out = self.conv1(x)
        if self.use_norm:
            out = self.norm(out)
        return self.relu(out) if self.use_activation else out

In [11]:
class ConvT(nn.Module):
    """``nn.ConvTranspose2d`` optionally followed by ``InstanceNorm2d`` and ``LeakyReLU(0.2)``.

    With the default 2x2 kernel, stride 2 and padding 0 the output has twice the
    input's height and width.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(2, 2),
                 stride=(2, 2),
                 padding=0,
                 use_norm=True,
                 use_activation=True):
        super().__init__()
        self.use_norm = use_norm
        self.use_activation = use_activation
        self.convT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        y = self.convT(x)
        y = self.norm(y) if self.use_norm else y
        return self.activation(y) if self.use_activation else y

In [12]:
class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gate.

    Computes ``x * sigmoid(conv2(silu(conv1(avgpool(x)))))``, so each channel of ``x``
    is rescaled by a factor in (0, 1).

    Args:
        in_channels: channels of the input (and of the output).
        out_channels: width of the hidden layer between the two 1x1 convolutions.
    """

    def __init__(self,
                 in_channels,
                 out_channels):
        super(SqueezeExcitation, self).__init__()

        self.adp_avg_pool = nn.AdaptiveAvgPool2d(1)
        # Plain 1x1 convolutions (with bias), not the Conv wrapper: its BatchNorm cannot run on
        # a 1x1 pooled map with batch size 1 in training mode, and its ReLU6 ahead of the
        # sigmoid would restrict the gate to [0.5, 1) so no channel could be suppressed.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.silu = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze: one value per channel; excite: per-channel gate broadcast back over H x W.
        scale = self.adp_avg_pool(x)
        scale = self.silu(self.conv1(scale))
        scale = self.sigmoid(self.conv2(scale))
        return x * scale

In [13]:
class Inverted_Res_Block(nn.Module):
    """Inverted residual (MBConv-like) block.

    Pipeline: 1x1 Conv expanding to ``in_channels * t`` -> grouped ``kernel_size`` Conv
    (carries ``stride``; depthwise since groups == channels) -> 1x1 Conv projecting to
    ``out_channels`` -> squeeze-excitation on the projected output. The input is added
    back when ``in_channels == out_channels`` and ``stride == 1``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 t,
                 padding,
                 kernel_size,
                 stride,
                 reduction=0.4,
                 ):
        super().__init__()
        self.use_residual = (stride == 1) and (in_channels == out_channels)

        expanded = in_channels * t
        self.conv1 = Conv(in_channels, expanded, kernel_size=(1, 1), stride=(1, 1), padding=0)
        self.conv2 = Conv(expanded, expanded, kernel_size=kernel_size, stride=stride,
                          use_grps=True, padding=padding)
        # NOTE(review): the projection keeps Conv's BatchNorm + ReLU6; MobileNetV2/EfficientNet
        # use a linear projection here. Confirm intent.
        self.conv3 = Conv(expanded, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=0)

        # NOTE(review): dividing by reduction=0.4 makes the SE hidden width 2.5x in_channels,
        # i.e. an expansion rather than a reduction. Confirm intent.
        self.squeeze_excitation = SqueezeExcitation(out_channels, int(in_channels / reduction))

    def forward(self, x):
        identity = x
        out = self.conv3(self.conv2(self.conv1(x)))
        out = self.squeeze_excitation(out)
        return out + identity if self.use_residual else out

In [14]:

# Backbone stage table read by Efficient_Net._get_layers.
#   list rows: [expand_ratio, channels, repeats, stride, kernel_size] -> a stage of
#              Inverted_Res_Blocks (stride applies to the first block of the stage only)
#   'S'      : inserts a Save_Block, so the activation reaching it is returned as a feature map
config = [
    [1, 16, 1, 1, 3],
    'S',
    [6, 24, 2, 2, 3],
    'S',
    [6, 40, 2, 2, 3],
    'S',
    [6, 80, 3, 2, 3],
    [6, 112, 3, 1, 3],
    'S',
    [6, 192, 4, 2, 3],
    [6, 320, 1, 1, 3],
    'S',
]

# version -> (phi_value, input resolution, drop_rate); drop_rate is not used in this notebook.
phi_values = dict(
    b0=(0, 512, 0.2),
    b1=(0.5, 640, 0.2),
    b2=(1, 768, 0.3),
    b3=(2, 896, 0.3),
    b4=(3, 1024, 0.4),
    b5=(4, 1280, 0.4),
    b6=(5, 1280, 0.5),
    b7=(6, 1536, 0.5),
)

In [15]:
class Save_Block(nn.Module):
    """Parameter-free marker layer.

    Efficient_Net.forward records the activation arriving at each Save_Block; the
    layer itself just hands its input back unchanged.
    """

    def forward(self, x):
        return x

In [16]:
class Efficient_Net(nn.Module):
    """EfficientNet-style backbone that returns intermediate feature maps.

    Layers come from the stage table ``config`` scaled by the version's coefficient
    ``phi = phi_values[version][0]``: depth by ``1.2 ** phi`` and width by ``1.1 ** phi``.
    ``forward`` returns the activations that reach each ``'S'`` (Save_Block) marker, in order.

    Args:
        in_channels: number of channels of the input images (used by the stem Conv).
        version: key into ``phi_values``, e.g. ``'b0'``.
        out_channels_model: output size of ``self.cls`` (not used by ``forward``).
        config: stage table; list rows are
            ``[expand_ratio, channels, repeats, stride, kernel_size]`` and ``'S'``
            entries insert a Save_Block.
        phi_values: version -> ``(phi, resolution, drop_rate)``.
    """

    def __init__(self,
                 in_channels,
                 version,
                 out_channels_model,
                 config=config,
                 phi_values=phi_values):
        super(Efficient_Net, self).__init__()

        self.config = config
        self.phi_values = phi_values
        depth_factor, width_factor = self._get_scale_params(version)
        out_channels = int(1280 * width_factor)

        # Stages are built from self.config and the stem from in_channels, so both
        # constructor arguments take effect.
        self.layers = self._get_layers(depth_factor, width_factor, out_channels,
                                       in_channels=in_channels)

        # NOTE(review): cls and adp_avg_pool are never used by forward().
        self.cls = nn.Linear(out_channels, out_channels_model)
        self.adp_avg_pool = nn.AdaptiveAvgPool2d(1)

    def _get_layers(self, depth_factor, width_factor, out_channels_model, in_channels=3):
        """Build stem Conv -> scaled stages / Save_Block markers -> final 1x1 Conv.

        ``out_channels_model`` is the channel count of the final 1x1 Conv.
        """
        layers = nn.ModuleList()

        channels = int(32 * width_factor)
        layers.append(Conv(in_channels, channels, stride=(2, 2)))
        prev_channels = channels
        for layer in self.config:
            if isinstance(layer, list):
                t, out_channels, repeats, stride, kernel_size = layer
                # Scaled width rounded down to a multiple of 4.
                out_channels = 4 * int(int(out_channels * width_factor) / 4)
                repeats = int(repeats * depth_factor)

                for repeat in range(repeats):
                    layers.append(
                        Inverted_Res_Block(
                            prev_channels,
                            out_channels,
                            t,
                            # Only the first block of a stage downsamples.
                            stride=stride if repeat == 0 else 1,
                            kernel_size=kernel_size,
                            padding=kernel_size // 2
                        )
                    )
                    prev_channels = out_channels
            elif isinstance(layer, str):
                layers.append(Save_Block())
        layers.append(Conv(prev_channels, out_channels_model, kernel_size=(1, 1), stride=(1, 1), padding=0))
        return layers

    def _get_scale_params(self, version, alpha=1.2, beta=1.1):
        """Return ``(alpha ** phi, beta ** phi)`` depth/width multipliers for ``version``."""
        phi_value, resolution, drop_rate = self.phi_values[version]
        depth_factor = alpha ** phi_value
        width_factor = beta ** phi_value
        return depth_factor, width_factor

    def forward(self, x):
        """Run every layer and return the activations captured at each Save_Block, in order."""
        x_out = []
        for layer in self.layers:
            if isinstance(layer, Save_Block):
                x_out.append(x)
            x = layer(x)
        # NOTE(review): the output of the final 1x1 Conv is computed but not returned.
        return x_out

In [None]:
def test(version = 'b0'):
    """Smoke-test Efficient_Net: run two random images at the version's resolution
    through a fresh backbone and return its list of feature maps."""
    _, res, _ = phi_values[version]
    batch = torch.randn((2, 3, res, res)).to(device)
    model = Efficient_Net(in_channels=3, version=version, out_channels_model=10).to(device)
    return model(batch)
# Disabled usage example for test(): kept as a string literal so it does not run.
'''x = test()
for x_ in x:
    print(x_.shape)'''

In [18]:
# Channel counts of the five feature maps Efficient_Net returns for each version (one per 'S'
# marker in config). They equal 4 * int(int(c * 1.1 ** phi) / 4) for
# c in (16, 24, 40, 112, 320), the backbone's width-rounding rule.
in_channels_list_ = dict(
    b0=[16, 24, 40, 112, 320],
    b1=[16, 24, 40, 116, 332],
    b2=[16, 24, 44, 120, 352],
    b3=[16, 28, 48, 132, 384],
    b4=[20, 28, 52, 148, 424],
    b5=[20, 32, 56, 160, 468],
    b6=[24, 36, 64, 180, 512],
    b7=[28, 40, 68, 196, 564],
)

In [19]:
class BiFPN_Layer(nn.Module):
    """One bidirectional pass over a list of feature maps.

    ``x[i]`` is expected to have ``in_channels_list[i]`` channels, highest resolution
    first, each level half the spatial size of the previous one (required by the
    concatenations below — TODO confirm for other backbones).

    * Downsampling path (``top_down``): a stride-2 Conv maps ``x[0]`` to ``x[1]``'s
      channels; each middle level ``i`` concatenates ``x[i]`` with the previous path
      output and applies the next stride-2 Conv. The last level is not processed here.
    * Upsampling path (``bottom_up``): walks the levels in reverse and applies a ConvT
      (2x upsampling) to ``cat(level, path feature, previous output)`` (the first step
      has no previous output and the last step no path feature). Output ``k`` has
      ``out_channels[k]`` channels.

    Returns the list of upsampling-path outputs, lowest-resolution level first.
    """

    def __init__(self,
                 in_channels_list = [2048 , 1024 , 512 , 256 , 128] ,
                 out_channels = [128 , 256 , 512 , 1024 , 2048]):
        super(BiFPN_Layer, self).__init__()

        self.top_down = nn.ModuleList()
        self.bottom_up = nn.ModuleList()

        n = len(in_channels_list)
        for i, channels in enumerate(in_channels_list):
            if i == 0:
                self.top_down.append(Conv(channels, in_channels_list[i + 1], stride=(2, 2)))
            elif i < n - 1:
                # Input is cat(x[i], previous path output): twice this level's channels.
                self.top_down.append(Conv(channels * 2, in_channels_list[i + 1], stride=(2, 2)))

        for i, channels in enumerate(Reverse(in_channels_list)):
            if i == 0:
                in_ch = channels * 2                          # cat(level, path feature)
            elif i == n - 1:
                in_ch = channels + out_channels[i - 1]        # cat(level, previous output)
            else:
                in_ch = channels * 2 + out_channels[i - 1]    # cat(level, path feature, previous output)
            self.bottom_up.append(ConvT(in_ch, out_channels[i]))

    def forward(self, x):
        n = len(x)

        # Downsampling path: x_1[k] is computed from level k (and x_1[k - 1] for k > 0).
        x_1 = []
        for i in range(n - 1):
            inp = x[0] if i == 0 else torch.cat([x[i], x_1[-1]], dim=1)
            x_1.append(self.top_down[i](inp))

        x_1_reversed = Reverse(x_1)
        x_reversed = Reverse(x)

        x_out = []
        for i in range(n):
            if i == 0:
                inp = torch.cat([x_reversed[0], x_1_reversed[0]], dim=1)
            elif i == n - 1:
                inp = torch.cat([x_reversed[-1], x_out[-1]], dim=1)
            else:
                # Match the previous output to this level's height AND width so the
                # concatenation also works for non-square feature maps.
                x_out[-1] = F.resize(x_out[-1], list(x_1_reversed[i].shape[-2:]))
                inp = torch.cat([x_reversed[i], x_1_reversed[i], x_out[-1]], dim=1)
            x_out.append(self.bottom_up[i](inp))
        return x_out

In [None]:
# Disabled smoke test for BiFPN_Layer on backbone features: kept as a string literal so it does not run.
'''version = 'b0'
bifpn = BiFPN_Layer(in_channels_list_[version] , 
                    Reverse(in_channels_list_[version]))
x = test(version = version)
z = bifpn(x)
for z_ in z:
    print(z_.shape)'''

In [21]:
class Repeated_BiFPN(nn.Module):
    """Stack of ``1 + phi`` BiFPN_Layer modules, where phi is the digit of ``version``.

    Every layer except the last outputs ``_get_out_channels(in_channels_list, phi)``
    channels; the last outputs ``out_channels_list_model``. Each layer's output channel
    list becomes the next layer's input channel list.

    NOTE(review): _get_out_channels multiplies every level's channels by 64 * 1.35**phi
    (e.g. 332 * 86 channels at b1), whereas EfficientDet uses 64 * 1.35**phi channels per
    level. Also, BiFPN_Layer returns levels lowest-resolution first, while its input is
    consumed highest-resolution first, so a second stacked layer (phi >= 1) appears to
    receive levels in the wrong order; only 'b0' (a single layer) looks runnable. Confirm.
    """

    def __init__(self,
                 version,
                 in_channels_list = [2048 , 1024 , 512 , 256 , 128] ,
                 out_channels_list_model = [128 , 256 , 512 , 1024 , 2048]):
        super().__init__()

        phi = {
            'b0': 0,
            'b1': 1,
            'b2': 2,
            'b3': 3,
            'b4': 4,
            'b5': 5,
            'b6': 6,
            'b7': 7,
        }[version]
        repeats = phi + 1
        hidden_channels = self._get_out_channels(in_channels_list, phi)

        self.layers = nn.ModuleList()
        for idx in range(repeats):
            layer_out = out_channels_list_model if idx == repeats - 1 else hidden_channels
            self.layers.append(BiFPN_Layer(in_channels_list, layer_out))
            in_channels_list = layer_out

    def _get_out_channels(self, in_channels_list, phi):
        """Scale each channel count by ``64 * 1.35 ** phi``, truncating to int."""
        factor = 64 * 1.35 ** phi
        return [int(channels * factor) for channels in in_channels_list]

    def forward(self, x):
        for bifpn in self.layers:
            x = bifpn(x)
        return x

In [None]:
# Disabled smoke test for Repeated_BiFPN on backbone features: kept as a string literal so it does not run.
'''version = 'b0'
repeated_bifpn = Repeated_BiFPN(version , 
                                in_channels_list=in_channels_list_[version]).to(device)
x = test(version = version)
z = repeated_bifpn(x)
for z_ in z:
    print(z_.shape)'''

In [23]:
class EfficientDet(nn.Module):
    def __init__(self ,  
                 version , 
                 in_channels_list_ = in_channels_list_):
        super(EfficientDet , self).__init__()

        self.efficientnet = Efficient_Net(3 , version , out_channels_model=10)
        self.repeated_bifpn = Repeated_BiFPN(version , 
                                             in_channels_list = in_channels_list_[version] , 
                                             out_channels_list_model=[128 , 128 , 128 , 128 , 128])

    def _get_sum(self , x):
        x_out = 0
        for i , x_ in enumerate(x):
            x[i] = f.adaptive_avg_pool2d(x_ , (13 , 13))
            x_out += x[i]
        return x_out

    def forward(self , x):
        x = self.efficientnet(x)
        x = self.repeated_bifpn(x)
        x = self._get_sum(x)
        return x

In [None]:
# Disabled earlier version of test_model (for EfficientDet): kept as a string literal so it does not run.
'''def test_model(version):
    phi , res , drop_rate = phi_values[version]
    num_examples = 2
    x = torch.randn((num_examples , 3 , res , res)).to(device)
    efficientdet = EfficientDet(version)
    x = efficientdet(x)
    print(x.shape)'''

In [25]:
#test_model('b0')

In [26]:
class Pred(nn.Module):
    """Convolutional prediction head applied to the fused feature map.

    Stacks ``3 + phi // 3`` 3x3 Conv layers (phi = digit of ``version``); spatial size
    is preserved. Hidden layers use Conv's BatchNorm + ReLU6. The final layer is linear,
    so the raw predictions (logits, log-scale sizes) can take any sign. It outputs
    ``(num_classes + 4) * B`` channels.

    Args:
        version: 'b0'..'b7'.
        in_channels: channels of the incoming feature map (EfficientDet produces 128).
        num_classes: per-anchor channels in addition to the 4 box values.
            NOTE(review): Loss reads channel 0 as objectness, 1:5 as the box and 5: as
            class logits, so a caller needing C class scores must pass ``C + 1``.
        B: anchors per cell.
        S: grid size (not used by the layers).
    """

    def __init__(self,
                 version,
                 in_channels = 128,
                 num_classes = 20,
                 B = 5,
                 S = 13):
        super(Pred, self).__init__()

        self.layers = nn.ModuleList()
        out_classes = num_classes + 4
        out_channels_model = out_classes * B

        version_list = {
            'b0': 0,
            'b1': 1,
            'b2': 2,
            'b3': 3,
            'b4': 4,
            'b5': 5,
            'b6': 6,
            'b7': 7
        }
        phi = version_list[version]

        repeats = int(3 + phi // 3)
        out_channels_coe = 64 * 1.35 ** phi
        # NOTE(review): hidden width is in_channels * 64 * 1.35**phi (8192 channels for b0 with
        # in_channels=128); EfficientDet's head uses 64 * 1.35**phi channels. Confirm intent.
        out_channels = int(in_channels * out_channels_coe)

        for repeat in range(repeats):
            is_last = repeat == repeats - 1
            if is_last:
                out_channels = out_channels_model
            # No BatchNorm/ReLU6 on the output layer: ReLU6 would clamp every prediction to [0, 6].
            self.layers.append(
                Conv(in_channels, out_channels,
                     use_norm=not is_last, use_activation=not is_last)
            )
            in_channels = out_channels

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
# Disabled smoke test for Pred on a random 128-channel 13x13 map: kept as a string literal so it does not run.
'''x = torch.randn((2 , 128 , 13 , 13)).to(device)
pred = Pred('b0')
z = pred(x)
#print(z.shape)
z = z.view(z.shape[0] , 5 , 13 , 13 , 24)
print(z.shape)'''

In [28]:
class EfficientDet_Model(nn.Module):
    """Full detector: EfficientDet features -> Pred head -> per-anchor grid predictions.

    ``forward`` returns a tensor of shape (N, B, H, W, K), where H x W is the head's
    output grid (S x S with the defaults) and K is the per-anchor vector length. With
    this constructor K == C + 5: objectness, 4 box values, C class logits, matching the
    indexing in Loss.

    Args:
        version: 'b0'..'b7'.
        B: anchors per cell.
        C: number of object classes.
        S: grid size.
    """

    def __init__(self,
                 version,
                 B = 5,
                 C = 20,
                 S = 13):
        super(EfficientDet_Model, self).__init__()

        self.efficientdet = EfficientDet(version)
        # Pred allots num_classes + 4 channels per anchor; 1 objectness + 4 box + C classes
        # needs C + 5, hence num_classes=C + 1. B and S are forwarded so they stay consistent.
        self.pred = Pred(version, num_classes=C + 1, B=B, S=S)

        self.B = B
        self.C = C
        self.S = S

    def forward(self, x):
        x = self.efficientdet(x)
        x = self.pred(x)
        n, _, h, w = x.shape
        # Split the channel axis into B groups first, then move each group's values last.
        # Viewing (N, B*K, H, W) directly as (N, B, H, W, K) would mix spatial positions into
        # the per-anchor vector. contiguous() hands the loss an independent tensor.
        return x.view(n, self.B, -1, h, w).permute(0, 1, 3, 4, 2).contiguous()

In [29]:
def test_model(version):
    """Smoke-test EfficientDet_Model: print the output shape for two random images.

    The model is moved to ``device`` like the input (previously only the input was,
    which fails on CUDA), and gradients are disabled since nothing is trained here.
    """
    _, res, _ = phi_values[version]
    num_examples = 2
    x = torch.randn((num_examples, 3, res, res)).to(device)
    efficientdet = EfficientDet_Model(version).to(device)
    with torch.no_grad():
        x = efficientdet(x)
    print(x.shape)

In [30]:
#test_model("b0")

In [31]:
class Dataset_(torch.utils.data.Dataset):
    """YOLO-format detection dataset yielding ``(image, targets)`` pairs.

    ``csv_file`` column 0 holds image file names (joined with ``img_dir``) and column 1
    label file names (joined with ``label_dir``). Each label line is
    ``class x y width height``; coordinates are presumably normalised to [0, 1]
    (TODO confirm).

    ``targets`` has shape (B, S, S, 6): ``[objectness, x_cell, y_cell, w_cell, h_cell, class]``.
    Objectness is 1 for the single anchor made responsible for a box. It is -1 for other
    free anchors in that cell whose width/height IoU with the box exceeds
    ``ignore_iou_thresh``, and 0 otherwise.
    """

    def __init__(self,
                 img_dir,
                 label_dir,
                 csv_file,
                 anchors,
                 S = 13,
                 B = 5,
                 C = 20,
                 version = 'b0',
                 phi_values = phi_values):
        super(Dataset_, self).__init__()

        self.img_dir = img_dir
        self.label_dir = label_dir
        self.df = pd.read_csv(csv_file)
        # (num_anchors, 2) anchor widths/heights, compared against the label widths/heights.
        self.anchors = torch.from_numpy(np.array(anchors))
        self.number_of_anchors_per_cell = B
        self.ignore_iou_thresh = 0.5
        self.C = C
        self.S = S
        self.B = B
        _, res, _ = phi_values[version]

        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((res, res)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        label_path = os.path.join(self.label_dir, self.df.iloc[idx, 1])
        boxes = self._load_boxes(label_path)
        img_path = os.path.join(self.img_dir, self.df.iloc[idx, 0])
        image = self._load_image(img_path)
        return image, self._build_targets(boxes)

    @staticmethod
    def _load_boxes(label_path):
        """Parse a label file into rows ``[x, y, width, height, class_label]`` (floats)."""
        boxes = []
        with open(label_path) as fh:
            for line in fh:
                values = line.split()
                if not values:
                    continue  # tolerate blank / trailing lines
                class_label, x, y, width, height = (float(v) for v in values)
                boxes.append([x, y, width, height, class_label])
        return boxes

    def _load_image(self, img_path):
        """Read an image as a (C, H, W) tensor and apply ``self.transforms``."""
        image = np.asarray(plt.imread(img_path))
        if image.ndim == 2:
            # Grayscale file: replicate to 3 channels so every sample is RGB.
            image = np.stack([image] * 3, axis=-1)
        # Drop an alpha channel (e.g. RGBA PNGs) for the same reason.
        image = np.ascontiguousarray(image[..., :3])
        image = torch.from_numpy(image).permute(2, 0, 1)
        if self.transforms:
            image = self.transforms(image)
        return image

    def _build_targets(self, boxes):
        """Encode the boxes into a (B, S, S, 6) target tensor (see class docstring)."""
        S = self.S
        targets = torch.zeros((self.B, S, S, 6))
        for box in torch.tensor(boxes):
            iou_anchors = iou_width_height(box[2:4], self.anchors)
            anchors_indices = iou_anchors.argsort(descending=True, dim=0)
            x, y, width, height, class_label = box
            # Clamp so a coordinate of exactly 1.0 falls in the last cell, not out of range.
            i = min(int(S * y), S - 1)
            j = min(int(S * x), S - 1)
            # One flag per box (not per anchor): only the best free anchor claims the box,
            # and the remaining high-IoU anchors reach the ignore branch.
            has_anchor = False
            for anchor_idx in anchors_indices:
                anchor = int(anchor_idx) % self.B
                anchor_taken = targets[anchor, i, j, 0]
                if not anchor_taken and not has_anchor:
                    targets[anchor, i, j, 0] = 1
                    targets[anchor, i, j, 1:5] = torch.stack(
                        [S * x - j, S * y - i, width * S, height * S]
                    )
                    targets[anchor, i, j, 5] = int(class_label)
                    has_anchor = True
                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[anchor, i, j, 0] = -1
        return targets

In [32]:
# Anchor (width, height) priors, presumably fractions of the image size: Dataset_ compares them
# with the raw label widths/heights. Loss compares anchors with cell-unit widths (width * S),
# so anchors passed to the loss must be multiplied by S.
anchors = [[0.28, 0.22], [0.38, 0.48], [0.9, 0.78], [0.07, 0.15], [0.15, 0.11]]
# Dataset root; override with the YOLO_DATA_DIR environment variable instead of editing paths.
DATA_DIR = os.environ.get('YOLO_DATA_DIR', '/content/drive/MyDrive/Yolo_Dataset')
dataset = Dataset_(
    img_dir=os.path.join(DATA_DIR, 'images'),
    label_dir=os.path.join(DATA_DIR, 'labels'),
    csv_file=os.path.join(DATA_DIR, 'train.csv'),
    anchors=anchors,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

In [None]:
# Preview a single batch: show its images, then print the image and target tensor shapes.
for x, y in dataloader:
    show_tensor_images(x)
    print(x.shape, y.shape, sep='\n')
    break

In [34]:
class Loss(nn.Module):
    """YOLO-style detection loss.

    ``predictions`` (N, B, S, S, 5 + C): ``[..., 0]`` objectness logit, ``[..., 1:3]``
    x/y logits (sigmoid -> offset within the cell), ``[..., 3:5]`` log-scale w/h relative
    to the anchors, ``[..., 5:]`` class logits.
    ``targets`` (N, B, S, S, 6) as built by Dataset_: objectness in {1, 0, -1 (ignored)},
    box ``(x_cell, y_cell, w_cell, h_cell)``, class index.
    ``anchors``: (B, 2) list or tensor in the same units as the target widths/heights
    (cell units for Dataset_ targets — TODO confirm at the call site).

    Neither ``predictions`` nor ``targets`` is modified.
    """

    def __init__(self):
        super(Loss, self).__init__()

        self.mse = nn.MSELoss()
        self.en = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.sigmoid = nn.Sigmoid()

        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, targets, anchors):
        # Build anchors on the predictions' device/dtype (a CPU tensor fails against CUDA
        # predictions) and shape them to broadcast over (N, B, S, S, 2).
        anchors = torch.as_tensor(anchors, dtype=predictions.dtype, device=predictions.device)
        anchors = anchors.reshape(1, -1, 1, 1, 2)
        obj = targets[..., 0] == 1
        noobj = targets[..., 0] == 0

        # Empty cells: BCE on the objectness logit pushes sigmoid(logit) towards 0, the same
        # sigmoid interpretation the object term uses (MSE on raw logits would target 0.5).
        no_obj_loss = self.bce(predictions[..., 0:1][noobj], targets[..., 0:1][noobj])

        if not obj.any():
            # No anchor assigned in this batch: the remaining means would be NaN.
            return self.lambda_noobj * no_obj_loss

        xy_pred = self.sigmoid(predictions[..., 1:3])
        box_preds = torch.cat([xy_pred, torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
        ious = intersection_over_union(box_preds[obj], targets[..., 1:5][obj]).detach()
        object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * targets[..., 0:1][obj])

        # Box term: sigmoid x/y vs cell offsets, raw w/h vs log(target / anchor). Built as new
        # tensors instead of writing into predictions/targets in place.
        box_pred_enc = torch.cat([xy_pred, predictions[..., 3:5]], dim=-1)
        box_target_enc = torch.cat(
            [targets[..., 1:3], torch.log(1e-16 + targets[..., 3:5] / anchors)], dim=-1
        )
        box_loss = self.mse(box_pred_enc[obj], box_target_enc[obj])

        class_loss = self.en(predictions[..., 5:][obj], targets[..., 5][obj].long())

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_obj_loss
            + self.lambda_class * class_loss
        )

In [None]:
# Disabled smoke test of Loss on one batch: kept as a string literal so it does not run.
'''version = 'b0'
loss_ = Loss().to(device)
efficientdet = EfficientDet_Model(version)
phi , res , _ = phi_values[version]
for x , y in dataloader:
    predictions = efficientdet(x)
    break
z = loss_(predictions , y , anchors)'''

In [39]:
# Model, loss and optimiser for training the 'b0' variant.
version = 'b0'
efficientdet = EfficientDet_Model(version).to(device)
loss_ = Loss().to(device)
# NOTE(review): beta1 = 0.5 is well below Adam's usual 0.9; confirm it is intended.
lr, betas = 0.002, (0.5, 0.999)
opt = torch.optim.Adam(efficientdet.parameters(), lr=lr, betas=betas)
epochs, display_steps = 200, 100

In [40]:
def train():
    """Train ``efficientdet`` for ``epochs`` epochs over ``dataloader``.

    Prints the mean loss of the last ``display_steps`` optimisation steps every
    ``display_steps`` steps. Relies on the notebook globals efficientdet, loss_, opt,
    dataloader, anchors, epochs, display_steps and device.
    """
    # Dataset_ stores target widths in grid-cell units (width * S) and the loss compares them
    # with exp(pred) * anchors, so convert the image-fraction anchors to cell units once.
    S = dataloader.dataset.S
    scaled_anchors = [[w * S, h * S] for w, h in anchors]

    efficientdet.train()
    mean_loss = 0
    cur_step = 0

    for epoch in range(epochs):
        for x, y in tqdm(dataloader):
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            y_ = efficientdet(x)
            loss = loss_(y_, y, scaled_anchors)
            loss.backward()
            opt.step()

            mean_loss += loss.item() / display_steps
            cur_step += 1
            if cur_step % display_steps == 0:
                print(f'Epoch {epoch} , Step {cur_step} , Mean YOLO Loss {mean_loss}')
                mean_loss = 0  # start a fresh averaging window

In [None]:
# Launch training with the model, optimiser and data loader defined above.
train()