In [1]:
from models.detector_model.model import ObjectDetectionModel
from models.detector_model.processor import TrainingProcessor
from models.detector_model.data_utils import TrainingDataset, COCOProcessor
from models.detector_model.new_loss import ObjectDetectionLoss
# from models.detector_model.loss import ObjectDetectionLoss
from torch.utils.data import DataLoader
import torch
import torch.optim as optim

grouped_classes = {
        "Metal": [
            "Metal bottle cap", "Metal lid", "Drink can", "Pop tab", "Scrap metal",
            "Food Can", "Aluminium blister pack", "Aluminium foil", "Aerosol"
        ],
        "Plastic": [
            "Plastic bottle cap", "Other plastic wrapper", "Six pack rings",
            "Single-use carrier bag", "Plastic straw", "Plastic glooves",
            "Plastic utensils", "Disposable plastic cup", "Other plastic bottle",
            "Tupperware", "Spread tub", "Garbage bag", "Other plastic container",
            "Other plastic", "Rope & strings", "Other plastic cup", "Plastic film",
            "Polypropylene bag", "Plastic lid", "Clear plastic bottle", "Squeezable tube",
            "Carded blister pack", "Crisp packet", "Meal carton"
        ],
        "Paper": [
            "Paper cup", "Paper bag", "Normal paper", "Paper straw", "Tissues",
            "Toilet tube", "Wrapping paper", "Pizza box", "Magazine paper",
            "Corrugated carton", "Egg carton", "Other carton", "Drink carton"
        ],
        "Glass": [
            "Glass jar", "Glass bottle", "Glass cup", "Broken glass"
        ],
        "Waste": [
            "Cigarette", "Food waste", "Foam cup",
            "Disposable food container", "Foam food container",
            "Shoe", "Unlabeled litter", "Styrofoam piece"
        ],
        "Battery": [
            "Battery"
        ],
    }

In [2]:
model = ObjectDetectionModel(num_classes=len(grouped_classes), num_anchors=4, grid_size=4)
model.count_parameters()
coco_processor = COCOProcessor(classes=grouped_classes)

MODEL PARAMETER SUMMARY
Total parameters:      390,700
Trainable parameters:  390,700
Non-trainable params:  0


In [3]:
extracted_trash = coco_processor.extract_annotations(
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/Dataset/Trash Detection.v14i.coco/train/_annotations.coco.json',
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/Dataset/Trash Detection.v14i.coco/train',
    convert=False
)

extracted_taco = coco_processor.extract_annotations(
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/TACO/data/annotations.json',
    'D:/Sakal/AI_FARM/Recycling_Classification/Dataset/TACO/data',
    convert=True
)

classes_names_trash = []
for label in extracted_trash:
    classes_names_trash.extend(label['Class'])
classes_names_trash = list(set(classes_names_trash))

classes_names_taco = []
for label in extracted_taco:
    classes_names_taco.extend(label['Class'])
classes_names_taco = list(set(classes_names_taco))

In [4]:
from PIL import Image
classes = [item for item, value in grouped_classes.items()] # ['Metal', 'Plastic', 'Paper', 'Glass', 'Waste', 'Battery']

processor = TrainingProcessor(
    input_size=448,
    grid_size=model.grid_size,
    num_anchors=model.num_anchors,
    classes=classes,
)

trash_dataset = TrainingDataset(data_json=extracted_trash, processor=processor, is_training=False)
trash_dataloader = DataLoader(trash_dataset, batch_size=40, shuffle=True)


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = ObjectDetectionLoss(processor=processor)
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,           # learning rate
    betas=(0.9, 0.999),# beta1 and beta2 for momentum estimates
    eps=1e-8,          # small constant for numerical stability
    weight_decay=0     # L2 regularization
)

num_epochs = 50
batch_interval = 50

In [6]:
# image_tensor, target_tensor, anchor_pose = processor.process_training_sample(
#     extracted_trash[100], apply_augmentation=False, get_anchors=True)

# with torch.no_grad():
#     output = model(image_tensor.unsqueeze(0).to(device))

# bboxes = processor.convert_yolo_output_to_bboxes(output[0], grid=True, class_tensor=True, conf_threshold=None)

In [7]:
# criterion._compute_single_image_loss_global(pred_data=model_bboxes, gt_data=bboxes, device='cuda')

In [8]:
# monitor = train_with_monitoring(
#     model=model,
#     dataloader=trash_dataloader,
#     loss_fn=criterion,
#     optimizer=optimizer,
#     num_epochs=50,
#     save_model_path='best_model.pth',
#     monitor_frequency=200, 
#     device='cuda'
# )

In [None]:
model.to(device)
training_lifetime_loss = 0.0
training_lifetime_batch = 0

for epoch in range(num_epochs):
    epoch_obj_loss = []
    epoch_cls_loss = []
    epoch_siou_loss = []


    model.train()  # Ensure model is in training mode
    epoch_loss = 0.0
    batch_interval_loss = 0.0
    num_batches = 0
    
    for i, (x, y) in enumerate(trash_dataloader):
        x, y = x.to(device), y.to(device)
        
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss_item = loss['total']
        
        epoch_obj_loss.append(loss['objectness'])
        epoch_cls_loss.append(loss['classification'])
        epoch_siou_loss.append(loss['siou'])

        # Check for invalid loss values
        if torch.isnan(loss_item) or torch.isinf(loss_item):
            print(f"Warning: Invalid loss detected at epoch {epoch+1}, batch {i}")
            print(f"Loss value: {loss_item.item()}")
            continue  # Skip this batch
        
        loss_item.backward()
        
        # Add gradient clipping for stability (especially important for your loss function)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        
        optimizer.step()
        
        # Accumulate losses
        loss_item = loss_item.item()
        epoch_loss += loss_item
        batch_interval_loss += loss_item
        training_lifetime_loss += loss_item
        num_batches += 1
        training_lifetime_batch += 1
        
        # Print batch interval statistics
        if i % (batch_interval) == 0 and i != 0:
            avg_interval_loss = batch_interval_loss / (batch_interval+1)  # Fixed division
            avg_epoch_loss_so_far = epoch_loss / num_batches
            avg_lifetime_loss = training_lifetime_loss / training_lifetime_batch
            
            print(f'\tBatch: [{i}/{len(trash_dataloader)}], '
                  f'Interval Loss: {avg_interval_loss:.4f}, '
                  f'Epoch Loss: {avg_epoch_loss_so_far:.4f}, '
                  f'Lifetime Loss: {avg_lifetime_loss:.4f}')
            print(f'\t\tSIoU Loss: {sum(epoch_siou_loss) / len(epoch_siou_loss)}, Objectness Loss: {sum(epoch_obj_loss) / len(epoch_obj_loss)}, Class Loss: {sum(epoch_cls_loss) / len(epoch_cls_loss)},')
            
            batch_interval_loss = 0.0
    
    # Epoch summary
    if num_batches > 0:  # Avoid division by zero
        avg_epoch_loss = epoch_loss / num_batches
        avg_lifetime_loss = training_lifetime_loss / training_lifetime_batch
        
        print(f'Epoch: [{epoch+1}/{num_epochs}], '
              f'Avg Epoch Loss: {avg_epoch_loss:.4f}, '
              f'Avg Lifetime Loss: {avg_lifetime_loss:.4f}')
        
        # Optional: Save checkpoint periodically
        if (epoch + 1) % 25 == 0: 
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
                'lifetime_loss': avg_lifetime_loss
            }
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
            print(f"Checkpoint saved at epoch {epoch+1}")

print("Training completed!")
print(f"Final average lifetime loss: {training_lifetime_loss / training_lifetime_batch:.4f}")

	Batch: [50/150], Interval Loss: 1.8644, Epoch Loss: 1.8644, Lifetime Loss: 1.8644
		SIoU Loss: 0.770252525806427, Objectness Loss: 0.08339868485927582, Class Loss: 0.07854638993740082,
	Batch: [100/150], Interval Loss: 1.6062, Epoch Loss: 1.7525, Lifetime Loss: 1.7525
		SIoU Loss: 0.7508613467216492, Objectness Loss: 0.061175450682640076, Class Loss: 0.06420508772134781,
Epoch: [1/50], Avg Epoch Loss: 1.6882, Avg Lifetime Loss: 1.6882
	Batch: [50/150], Interval Loss: 1.5343, Epoch Loss: 1.5343, Lifetime Loss: 1.6491
		SIoU Loss: 0.7091919183731079, Objectness Loss: 0.01916486769914627, Class Loss: 0.03877779468894005,
	Batch: [100/150], Interval Loss: 1.4471, Epoch Loss: 1.5055, Lifetime Loss: 1.6147
		SIoU Loss: 0.6972808241844177, Objectness Loss: 0.017703143879771233, Class Loss: 0.037744034081697464,
Epoch: [2/50], Avg Epoch Loss: 1.4993, Avg Lifetime Loss: 1.5937
	Batch: [50/150], Interval Loss: 1.4470, Epoch Loss: 1.4470, Lifetime Loss: 1.5724
		SIoU Loss: 0.6737440824508667, Ob