In [None]:
import os
import json
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
from model import Detector
from PIL import Image
from torch.utils.data import DataLoader
from utils import file_format_counter
from dataloader import SARDet
from torchsummary import summary

Data Exploration

In [None]:
# Root of the SARDet dataset; the train/ and val/ image folders are read from here.
data_dir = "/home/hasanmog/CNN-VS-ViT/Datasets/SARDet"

# os.listdir returns entries in arbitrary order. Sort the listings so that the
# seeded shuffle used later for the train/test split sees the same input order
# on every run and machine, which is what makes the split reproducible.
train = sorted(os.listdir(os.path.join(data_dir, 'train')))
val = sorted(os.listdir(os.path.join(data_dir, 'val')))
len(train) , len(val)

In [None]:
# Carve a held-out test split out of the training file list: shuffle in place
# with a fixed seed, keep the first 70000 names for training, the rest for test.
random.seed(50)
random.shuffle(train)
train , test = train[:70000] , train[70000:]
len(train) , len(test)

In [None]:
# Open one training image, chosen by its position in the file list, and preview it.
index = 6000
img_path = os.path.join(data_dir+'/train/'+train[index])
img = Image.open(img_path)
print(img)  # textual repr of the PIL image object
plt.imshow(img)

In [None]:
# Report how many png / jpg / bmp files each split contains.
# file_format_counter is a project helper; it is unpacked here as a 3-tuple.
for split_name , split_files in (("train_set" , train) , ("val_set" , val) , ("test_set" , test)):
    png , jpg , bmp = file_format_counter(split_files)
    print(f"{split_name}: png={png} , jpg={jpg} , bmp={bmp}")

In [None]:
# Show the first entry of the training file list as a quick sanity check.
print(train[0])

In [None]:
# Paths of the annotation files stored under the dataset root.
# Only the training annotations are parsed in this cell; val_json is just defined.
train_json = os.path.join(data_dir+'/train.json')
val_json = os.path.join(data_dir+'/val.json')

with open(train_json) as anno_file:
    train_anno = json.load(anno_file)

# Peek at the structure: top-level keys, the first image record, the first
# annotation record, and the category list.
(
    train_anno.keys() ,
    train_anno['images'][0] ,
    train_anno['annotations'][0] ,
    train_anno['categories'] ,
)

In [None]:
# Build one SARDet dataset per split; each gets its file list and matching mode string.
files_by_mode = {'train': train , 'val': val , 'test': test}
train_set , val_set , test_set = (
    SARDet(data_dir=data_dir , imgs=split_imgs , mode=split_mode)
    for split_mode , split_imgs in files_by_mode.items()
)
len(train_set) , len(val_set) , len(test_set)

In [None]:
# Run configuration: mini-batch size and compute device (GPU when available, else CPU).
BATCH_SIZE = 8
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Batch each split. Training and validation batches are shuffled; the test
# loader keeps the dataset order.
train_loader = DataLoader(train_set , batch_size=BATCH_SIZE , shuffle=True)
val_loader = DataLoader(val_set , batch_size=BATCH_SIZE , shuffle=True)
test_loader = DataLoader(test_set , batch_size=BATCH_SIZE , shuffle=False)

# Number of batches per split.
len(train_loader) , len(val_loader) , len(test_loader)

In [None]:
# Fetch a single validation item to inspect what the dataset returns.
# The next cell reads its 'image_tensor' and 'bboxes' entries.
sample = val_set[999]
sample

In [None]:
# Move axis 0 of the image tensor to the end so matplotlib can draw it
# (presumably channels-first -> channels-last; confirm against the dataset).
image = sample['image_tensor'].numpy().transpose(1 , 2 , 0)
print("image size : " , image.shape)

fig , ax = plt.subplots(1)
ax.imshow(image , cmap='viridis')

# Overlay every box. Rectangle takes (x, y) as the anchor corner followed by
# width and height, so each box is read as [x, y, w, h].
for box in sample['bboxes']:
    left , top , width , height = box[0] , box[1] , box[2] , box[3]
    outline = patches.Rectangle((left , top) , width , height ,
                                linewidth=3 , edgecolor='r' , facecolor='none')
    ax.add_patch(outline)

plt.show()

In [None]:
# Instantiate the detector, move it to the selected device, and print a
# layer-by-layer summary for a 3x256x256 input.
model = Detector()
model = model.to(device)
summary(model , input_size=(3 , 256 , 256))

In [None]:
# Smoke-test the detector with one synthetic 800x800 RGB image.
# randint's upper bound is exclusive, so randint(0, 1, ...) produced an all-zero
# array; use 256 to draw real pixel values in [0, 255].
input_array = np.random.randint(0, 256, (800, 800, 3))
input_tensor = torch.from_numpy(input_array).float()  # Convert to float tensor
print(input_tensor.shape)
input_tensor = input_tensor.permute(2, 0, 1).unsqueeze(0)  # HWC -> CHW, then add a batch dimension

input_tensor = input_tensor.to(device)

# This is an inference-only check: run it in eval mode with autograd disabled,
# so no graph is built and mode-dependent layers are not affected by the dummy
# input. The previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        outputs = model(input_tensor)
finally:
    model.train(was_training)
outputs[: , : , 0 , 0] , outputs.shape

In [None]:
from postprocessing import convert_to_mins_maxes , non_max_suppression , process_boxes
from model import decode_outputs

# Decode the raw network output into boxes, objectness scores and class scores.
# The middle value is named `objectness` rather than `object`, so the Python
# builtin `object` is not shadowed for the rest of the kernel session.
boxes , objectness , class_scores = decode_outputs(outputs)
print(f"After decoding : {boxes.shape} , {objectness.shape} , {class_scores.shape}")

# Flatten to one row per candidate: 4 box values and 6 class scores each.
boxes = boxes.reshape(-1, 4)
class_scores = class_scores.reshape(-1, 6)
print(boxes.shape)
print(class_scores.shape)
assert boxes.shape[0] == class_scores.shape[0], "Mismatch in bounding boxes and class scores counts"

# Filter the candidates with non-maximum suppression and report how many survive.
picked_boxes, picked_scores, picked_classes = non_max_suppression(boxes,class_scores)
len(picked_boxes) , len(picked_scores) , len(picked_classes)

In [None]:
# Import the training entry point under an alias. The bare name `train` already
# holds the list of training file names in this notebook; rebinding it to a
# function would break any earlier cell that is re-run afterwards.
from engine import train as train_model

# Short training run: 2 epochs, initial learning rate 1e-3 with the
# 'exponential' schedule option, and no output directory (out_dir=None).
train_model(model = model , train_loader=train_loader , val_loader=val_loader ,
            lr = 0.001 , lr_schedule = 'exponential' , epochs = 2 ,
            out_dir = None , device = device )