In [1]:
from models.detector_model.model import ObjectDetectionModel
from models.detector_model.processor import TrainingProcessor
from models.detector_model.data_utils import TrainingDataset, COCOProcessor
from models.detector_model.loss_grid import ObjectDetectionLoss
from torch.utils.data import DataLoader
import torch
import torch.optim as optim

grouped_classes = {
        "Metal": [
            "Metal bottle cap", "Metal lid", "Drink can", "Pop tab", "Scrap metal",
            "Food Can", "Aluminium blister pack", "Aluminium foil", "Aerosol"
        ],
        "Plastic": [
            "Plastic bottle cap", "Other plastic wrapper", "Six pack rings",
            "Single-use carrier bag", "Plastic straw", "Plastic glooves",
            "Plastic utensils", "Disposable plastic cup", "Other plastic bottle",
            "Tupperware", "Spread tub", "Garbage bag", "Other plastic container",
            "Other plastic", "Rope & strings", "Other plastic cup", "Plastic film",
            "Polypropylene bag", "Plastic lid", "Clear plastic bottle", "Squeezable tube",
            "Carded blister pack", "Crisp packet", "Meal carton"
        ],
        "Paper": [
            "Paper cup", "Paper bag", "Normal paper", "Paper straw", "Tissues",
            "Toilet tube", "Wrapping paper", "Pizza box", "Magazine paper",
            "Corrugated carton", "Egg carton", "Other carton", "Drink carton"
        ],
        "Glass": [
            "Glass jar", "Glass bottle", "Glass cup", "Broken glass"
        ],
        "Waste": [
            "Cigarette", "Food waste", "Foam cup",
            "Disposable food container", "Foam food container",
            "Shoe", "Unlabeled litter", "Styrofoam piece"
        ],
        "Battery": [
            "Battery"
        ],
    }

In [2]:
model = ObjectDetectionModel(num_classes=len(grouped_classes), num_anchors=3, grid_size=3)
# model.load_state_dict(torch.load('D:\Sakal\AI_FARM\Recycling_Classification\src\checkpoint_latest.pth')['model_state_dict'])
model.count_parameters()
coco_processor = COCOProcessor(classes=grouped_classes)

MODEL PARAMETER SUMMARY
Total parameters:      416,801
Trainable parameters:  416,801
Non-trainable params:  0


In [3]:
extracted_trash = coco_processor.extract_annotations(
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/Dataset/Trash Detection.v14i.coco/train/_annotations.coco.json',
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/Dataset/Trash Detection.v14i.coco/train',
    convert=False
)

extracted_taco = coco_processor.extract_annotations(
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/TACO/data/annotations.json',
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/TACO/data',
    convert=True
)

classes_names_trash = []
for label in extracted_trash:
    classes_names_trash.extend(label['Class'])
classes_names_trash = list(set(classes_names_trash))

classes_names_taco = []
for label in extracted_taco:
    classes_names_taco.extend(label['Class'])
classes_names_taco = list(set(classes_names_taco))

In [4]:
dataset = extracted_taco + extracted_trash
len(dataset)

7500

In [5]:
from PIL import Image
classes = [item for item, value in grouped_classes.items()] # ['Metal', 'Plastic', 'Paper', 'Glass', 'Waste', 'Battery']

processor = TrainingProcessor(
    input_size=448,
    grid_size=model.grid_size,
    num_anchors=model.num_anchors,
    classes=classes,
)

trash_dataset = TrainingDataset(data_json=dataset, processor=processor, is_training=False)
trash_dataloader = DataLoader(trash_dataset, batch_size=40, shuffle=True)


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = ObjectDetectionLoss(processor=processor, bbox_loss_weight=1.0, cls_loss_weight=2.0, obj_loss_weight=2.0, pos_obj_weight=1.5, neg_obj_weight=0.5)
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,           # learning rate
    betas=(0.9, 0.999),# beta1 and beta2 for momentum estimates
    eps=1e-8,          # small constant for numerical stability
    weight_decay=0     # L2 regularization
)

num_epochs = 50
batch_interval = 25

In [None]:
# image_tensor, target_tensor, anchor_pose = processor.process_training_sample(
#     extracted_trash[1530], apply_augmentation=False, get_anchors=True)

# model.to(device)
# with torch.no_grad():
#     output = model(image_tensor.unsqueeze(0).to(device))

# bboxes = processor.convert_output_to_bboxes(output[0], grid=True, class_tensor=True, conf_threshold=None)

# neg_processed_bboxes = []
# for item in bboxes:
#     bbox = {
#         'bbox': item['bbox'],
#         'conf': torch.sigmoid(item['conf']),
#         'class_tensor': torch.sigmoid(item['class_tensor']),
#         'grid': item['grid'],
#         'class_id': torch.argmax(torch.sigmoid(item['class_tensor']))
#     }
#     if torch.sigmoid(item['conf']) > 0.32:
#         neg_processed_bboxes.append(bbox)

# neg_processed_bboxes

In [8]:
# processor.draw_bbox_on_image(image_tensor, neg_processed_bboxes)

In [9]:
# monitor = train_with_monitoring(
#     model=model,
#     dataloader=trash_dataloader,
#     loss_fn=criterion,
#     optimizer=optimizer,
#     num_epochs=50,
#     save_model_path='best_model.pth',
#     monitor_frequency=200, 
#     device='cuda'
# )

In [None]:
model.to(device)
training_lifetime_loss = 0.0
training_lifetime_batch = 0

for epoch in range(num_epochs):
    epoch_obj_loss = []
    epoch_cls_loss = []
    epoch_ciou_loss = []
    epoch_pos_obj_loss = []


    model.train() 
    epoch_loss = 0.0
    batch_interval_loss = 0.0
    num_batches = 0
    
    for i, (x, y) in enumerate(trash_dataloader):
        x, y = x.to(device), y.to(device)
        
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss_item = loss['total']
        
        epoch_obj_loss.append(loss['objectness'])
        epoch_cls_loss.append(loss['classification'])
        epoch_ciou_loss.append(loss['ciou'])
        try:
            epoch_pos_obj_loss.append(loss['positive_objectness'])
        except:
            pass

        # Check for invalid loss values
        if torch.isnan(loss_item) or torch.isinf(loss_item):
            print(f"Warning: Invalid loss detected at epoch {epoch+1}, batch {i}")
            print(f"Loss value: {loss_item.item()}")
            continue  
        
        loss_item.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()
        
        # Accumulate losses
        loss_item = loss_item.item()
        epoch_loss += loss_item
        batch_interval_loss += loss_item
        training_lifetime_loss += loss_item
        num_batches += 1
        training_lifetime_batch += 1
        
        # Print batch interval statistics
        if i % (batch_interval) == 0 and i != 0:
            avg_interval_loss = batch_interval_loss / (batch_interval+1)  # Fixed division
            avg_epoch_loss_so_far = epoch_loss / num_batches
            avg_lifetime_loss = training_lifetime_loss / training_lifetime_batch
            
            print(f'\tBatch: [{i}/{len(trash_dataloader)}], '
                  f'Interval Loss: {avg_interval_loss:.4f}, '
                  f'Epoch Loss: {avg_epoch_loss_so_far:.4f}, '
                  f'Lifetime Loss: {avg_lifetime_loss:.4f}')
            print(f'\t\tciou Loss: {sum(epoch_ciou_loss) / len(epoch_ciou_loss)}, Objectness Loss: {sum(epoch_obj_loss) / len(epoch_obj_loss)}, Positive Objectness Loss: {sum(epoch_pos_obj_loss) / len(epoch_pos_obj_loss)}, Class Loss: {sum(epoch_cls_loss) / len(epoch_cls_loss)},')
            
            batch_interval_loss = 0.0
    
    # Epoch summary
    if num_batches > 0:
        avg_epoch_loss = epoch_loss / num_batches
        avg_lifetime_loss = training_lifetime_loss / training_lifetime_batch
        
        print(f'Epoch: [{epoch+1}/{num_epochs}], '
              f'Avg Epoch Loss: {avg_epoch_loss:.4f}, '
              f'Avg Lifetime Loss: {avg_lifetime_loss:.4f}')
        
        if (epoch + 1) % 25 == 0: 
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
                'lifetime_loss': avg_lifetime_loss
            }
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
            print(f"Checkpoint saved at epoch {epoch+1}")

print("Training completed!")
print(f"Final average lifetime loss: {training_lifetime_loss / training_lifetime_batch:.4f}")

	Batch: [25/188], Interval Loss: 0.5726, Epoch Loss: 0.5726, Lifetime Loss: 0.5726
		ciou Loss: 0.2456081360578537, Objectness Loss: 0.20305056869983673, Positive Objectness Loss: 0.6890256404876709, Class Loss: 0.12390400469303131,
	Batch: [50/188], Interval Loss: 0.5544, Epoch Loss: 0.5745, Lifetime Loss: 0.5745
		ciou Loss: 0.2485014945268631, Objectness Loss: 0.20168587565422058, Positive Objectness Loss: 0.678826093673706, Class Loss: 0.12436062097549438,
	Batch: [75/188], Interval Loss: 0.4851, Epoch Loss: 0.5515, Lifetime Loss: 0.5515
		ciou Loss: 0.23841601610183716, Objectness Loss: 0.19394990801811218, Positive Objectness Loss: 0.6671144366264343, Class Loss: 0.11914939433336258,
	Batch: [100/188], Interval Loss: 0.4454, Epoch Loss: 0.5297, Lifetime Loss: 0.5297
		ciou Loss: 0.2288467139005661, Objectness Loss: 0.18687878549098969, Positive Objectness Loss: 0.6595014929771423, Class Loss: 0.11393117159605026,
	Batch: [125/188], Interval Loss: 0.4250, Epoch Loss: 0.5123, Lifet