In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import datasets, models, transforms, tv_tensors
import torch.utils.data
from torchvision.transforms import v2
from torchvision.io import read_image
from torch.nn.utils.rnn import pad_sequence
from torchvision.utils import draw_bounding_boxes

import numpy as np
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x1a6f6e3d250>

In [2]:
#from torchvision.io.image import decode_image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights #fcn_resnet50, FCN_ResNet50_Weights
from torchvision.models.segmentation import lraspp_mobilenet_v3_large, LRASPP_MobileNet_V3_Large_Weights
from torchvision.models.segmentation.lraspp import LRASPPHead
from torchvision.transforms.functional import to_pil_image

In [3]:
import train as engine

In [4]:
import coco_utils, presets, transforms, utils, v2_extras

In [5]:
torch.manual_seed(0)

<torch._C.Generator at 0x1a6eec14770>

In [6]:
torch.cuda.is_available()

True

In [7]:
device=torch.device('cuda')
print(f"Using {device} device")

Using cuda device


In [8]:
cudnn.enabled = True

In [9]:
cudnn.benchmark = False
cudnn.deterministic = True

In [10]:
mweights=torch.load('mobilenet.pth',device)

  mweights=torch.load('mobilenet.pth',device)


In [11]:
mweights

OrderedDict([('features.0.0.weight',
              tensor([[[[-1.3784e+00, -1.6226e+00, -1.6820e+00],
                        [-1.3891e+00, -1.3942e+00, -1.4535e+00],
                        [-1.5137e+00, -1.5199e+00, -1.4447e+00]],
              
                       [[-9.7659e-01, -1.0530e+00, -9.6557e-01],
                        [-7.5065e-01, -8.0721e-01, -7.1942e-01],
                        [-9.3843e-01, -8.1846e-01, -8.6168e-01]],
              
                       [[ 5.9471e-01,  7.1883e-01,  7.4485e-01],
                        [ 6.8496e-01,  7.4768e-01,  7.5413e-01],
                        [ 4.9358e-01,  6.9060e-01,  5.4507e-01]]],
              
              
                      [[[ 3.7429e-01,  9.5659e-01,  1.8253e-01],
                        [ 1.7528e-01,  4.7656e-01, -2.8317e-01],
                        [ 7.2440e-01,  1.3856e+00,  9.3563e-01]],
              
                       [[ 1.2561e-01,  6.3788e-01, -1.0329e-01],
                        [-2.0830e-01, 

In [12]:
class_names=['__background__','F-15 Eagle',
 'F-22 Raptor',
 'a-10 thunderbolt ii',
 'ac-130 ghostrider',
 'b-1 lancer',
 'b-2 spirit',
 'b-52 stratofortress',
 'f-35 lightning ii']

In [13]:
model = lraspp_mobilenet_v3_large(weights=None,weights_backbone=mweights)



In [14]:
model.classifier.high_classifier.out_channels= len(class_names)
model.classifier.low_classifier.out_channels= len(class_names)

In [15]:
model

LRASPP(
  (backbone): IntermediateLayerGetter(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1

In [16]:
train="data/av_sem/train"
train_ann="data/av_sem/train/_annotations.coco.json"

In [17]:
val="data/av_sem/valid"
val_ann="data/av_sem/valid/_annotations.coco.json"

In [18]:
 import v2_extras
from torchvision.datasets import wrap_dataset_for_transforms_v2

transforms_ = v2.Compose([v2_extras.CocoDetectionToVOCSegmentation(),
                        v2.RandomResizedCrop((520), antialias=True),
    v2.RandomHorizontalFlip(),
    #v2.CenterCrop(520),
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


In [19]:
transforms_

Compose(
      CocoDetectionToVOCSegmentation()
      RandomResizedCrop(size=(520, 520), scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=InterpolationMode.BILINEAR, antialias=True)
      RandomHorizontalFlip(p=0.5)
      PILToTensor()
      ToDtype(scale=True)
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
)

In [20]:
dataset = torchvision.datasets.CocoDetection(train, train_ann, transforms=transforms_)
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [21]:
dataset.coco.cats

{0: {'id': 0, 'name': 'air-force-air-vehicles', 'supercategory': 'none'},
 1: {'id': 1, 'name': 'a-10', 'supercategory': 'air-force-air-vehicles'},
 2: {'id': 2, 'name': 'ac-130', 'supercategory': 'air-force-air-vehicles'},
 3: {'id': 3, 'name': 'b-1', 'supercategory': 'air-force-air-vehicles'},
 4: {'id': 4, 'name': 'b-2', 'supercategory': 'air-force-air-vehicles'},
 5: {'id': 5, 'name': 'b-52', 'supercategory': 'air-force-air-vehicles'},
 6: {'id': 6, 'name': 'f-15', 'supercategory': 'air-force-air-vehicles'},
 7: {'id': 7, 'name': 'f-22', 'supercategory': 'air-force-air-vehicles'},
 8: {'id': 8, 'name': 'f-35', 'supercategory': 'air-force-air-vehicles'}}

In [22]:
val_ds = torchvision.datasets.CocoDetection(val,val_ann, transforms=transforms_)
val_ds = wrap_dataset_for_transforms_v2(val_ds, target_keys=["masks","labels"])

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [23]:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.RandomSampler(val_ds)

In [24]:
def coll_fn(data):
    image=next(iter(data))[0]
    #boxes=next(iter(dataset))[1]['boxes'].type(torch.LongTensor)
    masks=next(iter(data))[1].type(torch.LongTensor)
    #labels=next(iter(dataset))[1]['labels'].type(torch.LongTensor)
    #bboxes=torch.Tensor(next(iter(dataset))[1]['bbox'])

    
    image = pad_sequence(image, batch_first=True)

    #data_dict= {'boxes_':boxes,'masks_':masks,'labels_':labels,'bbox_':bboxes}
    #data_list=[boxes,masks,labels,bboxes]
    return image.unsqueeze(0), masks.unsqueeze(0)
    

In [25]:
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    #num_workers=16,
    #shuffle=True,
    sampler=train_sampler,
    # We need a custom collation function here, since the object detection
    # models expect a sequence of images and target dictionaries. The default
    # collation function tries to torch.stack() the individual elements,
    # which fails in general for object detection, because the number of bounding
    # boxes varies between the images of the same batch.
    collate_fn= utils.collate_fn#lambda batch: tuple(zip(*batch)) 
)

In [26]:
val_data_loader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=8,
    #num_workers=16,
    #shuffle=False,
    sampler=test_sampler,
    # We need a custom collation function here, since the object detection
    # models expect a sequence of images and target dictionaries. The default
    # collation function tries to torch.stack() the individual elements,
    # which fails in general for object detection, because the number of bounding
    # boxes varies between the images of the same batch.
    collate_fn= utils.collate_fn#lambda batch: tuple(zip(*batch))
)

In [27]:
count=5
epochs=15
# total_correct = 0
# total_samples = 0
# running_loss = 0.0

In [28]:
criterion = nn.CrossEntropyLoss().cuda()
params = [p for p in model.parameters() if p.requires_grad]
# Observe that all parameters are being optimized
optimizer = optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=1e-4)
scaler=torch.amp.GradScaler('cuda',enabled=True)
# Decay LR by a factor of 0.1 every 7 epochs
#lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer,total_iters=len(data_loader) * (epochs-count), power=0.9)


In [29]:
def evaluate(model,criterion,loader,device):
    model.eval()
    total_correct = 0
    total_samples = 0
    running_loss = 0.0
    with torch.inference_mode():
        for image, target in loader:
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            _, predicted = torch.max(output['out'], 1)
            total_correct += (predicted == target).sum().item()
            total_samples += target.size(0)
            running_loss += loss.item() * image.size(0)
        val_loss = running_loss / len(loader.dataset)
        print(f'Epoch {epoch}: Validation Loss = {val_loss:.5f}')
        #accuracy =  total_correct / total_samples
        #print(f'Epoch {epoch+1}: Val Accuracy = {accuracy:.5f}')
        # print(f'Epoch {epoch}: Loss = {val_loss:.5f}')

In [30]:
checkpoint = torch.load('models/c_lraspp_custom.pt', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

In [31]:
model.to(device)

LRASPP(
  (backbone): IntermediateLayerGetter(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1

In [32]:
for epoch in range(epochs):
    # train for one epoch, printing every 10 iterations
    engine.train_one_epoch(model, engine.criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq=8, scaler=scaler)
    evaluate(model,engine.criterion,val_data_loader,device)
    count+=1
    torch.save(model,'models/lraspp_custom.pt')
    torch.save(model.state_dict(),'models/lraspp_custom.pth')
    # update the learning rate
    #lr_scheduler.step()
    # evaluate on the test dataset
    #engine.evaluate(model, val_data_loader, device, len(class_names))

  with torch.cuda.amp.autocast(enabled=scaler is not None):


Epoch: [0]  [ 0/48]  eta: 0:01:57  lr: 0.0  loss: 0.1400 (0.1400)  time: 2.4523  data: 0.1710  max mem: 1983
Epoch: [0]  [ 8/48]  eta: 0:06:43  lr: 0.0  loss: 0.1734 (0.2103)  time: 10.0969  data: 0.1784  max mem: 1983
Epoch: [0]  [16/48]  eta: 0:06:30  lr: 0.0  loss: 0.1734 (0.2020)  time: 12.2130  data: 0.1531  max mem: 1983
Epoch: [0]  [24/48]  eta: 0:05:08  lr: 0.0  loss: 0.1614 (0.1865)  time: 14.4627  data: 0.1427  max mem: 1983
Epoch: [0]  [32/48]  eta: 0:03:31  lr: 0.0  loss: 0.1759 (0.2003)  time: 14.1793  data: 0.1399  max mem: 1983
Epoch: [0]  [40/48]  eta: 0:01:50  lr: 0.0  loss: 0.1818 (0.1970)  time: 14.9386  data: 0.1305  max mem: 1983
Epoch: [0] Total time: 0:10:54
Epoch 0: Validation Loss = 0.22853
Epoch: [1]  [ 0/48]  eta: 0:08:07  lr: 0.0  loss: 0.1396 (0.1396)  time: 10.1486  data: 0.1576  max mem: 1983
Epoch: [1]  [ 8/48]  eta: 0:07:11  lr: 0.0  loss: 0.1399 (0.2767)  time: 10.7976  data: 0.1246  max mem: 1983
Epoch: [1]  [16/48]  eta: 0:05:45  lr: 0.0  loss: 0.165

In [33]:
torch.save(model.state_dict(),'models/lraspp_custom.pth')

In [34]:
torch.save(model,'models/lraspp_custom.pt')

In [35]:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            #'loss': loss,
            }, 'models/c_lraspp_custom.pt')

# error corrections