# Imports

In [1]:
import time
import torch
from torch import nn, optim
from tqdm import tqdm
from dataset import WheatSegDatasetDETR
from Detr import DETR
import cv2
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import wandb
import os
from torch.utils.data import DataLoader, Subset
from segmentation_models_pytorch.losses import TverskyLoss
# Import TverskyLoss from torch

from utils import *
from definitions import *

# Select MPS if available, otherwise CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Using device: {device}")
DETR_MODEL_SAVE_PATH = "detr.pth"
CSV_PATH = "/Users/guyperets/Documents/DeepLearning_ImageProcessing/FinalProject/final_IP_project/data_detr/train.csv"
IMAGES_DIR_PATH = "/Users/guyperets/Documents/DeepLearning_ImageProcessing/FinalProject/final_IP_project/data_detr/train"

Using device: mps


# DataLoaders, Model, Loss & Optimizer

In [2]:
full_dataset = WheatSegDatasetDETR(csv_path=CSV_PATH, images_dir=IMAGES_DIR_PATH)
# Split to Train and Validation using VAL_RATIO
val_indices = np.random.choice(
    len(full_dataset),
    size=int(len(full_dataset) * VAL_RATIO),
    replace=False
)
val_dataset = Subset(full_dataset, val_indices)
train_indices = list(set(range(len(full_dataset))) - set(val_indices))
train_dataset = Subset(full_dataset, train_indices)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    collate_fn=full_dataset.collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=full_dataset.collate_fn
)

print(f"number of training samples: {len(train_loader.dataset)}")
print(f"number of validation samples: {len(val_loader.dataset)}")
print(f"dataloaders created with batch size {BATCH_SIZE} and {NUM_WORKERS} workers")
print(f"=== Dataloaders Summary ===")
print(f"Train Loader: {len(train_loader)} batches")
print(f"Validation Loader: {len(val_loader)} batches")


model   = DETR(num_classes=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"=== Model Summary ===")
print(model)

number of training samples: 2738
number of validation samples: 684
dataloaders created with batch size 64 and 8 workers
=== Dataloaders Summary ===
Train Loader: 43 batches
Validation Loader: 11 batches


Using cache found in /Users/guyperets/.cache/torch/hub/pytorch_vision_v0.10.0


=== Model Summary ===
DETR(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      



# Dataset Sanity Check

In [3]:
# Debug: Test dataset creation and access
print(f"Train dataset length: {len(train_dataset)}")
print(f"Val dataset length: {len(val_dataset)}")

# Try to access the first item
try:
    first_item = train_dataset[0]
    print(first_item[1].items())
    print("Dataset access successful!")
except Exception as e:
    print(f"Error accessing first item: {e}")
    
# Check if datasets have any items
if len(train_dataset) == 0:
    print("ERROR: Train dataset is empty!")
if len(val_dataset) == 0:
    print("ERROR: Validation dataset is empty!")

Train dataset length: 2738
Val dataset length: 684
dict_items([('boxes', tensor([[0.8037, 0.8970, 0.1133, 0.0771],
        [0.1567, 0.5732, 0.1494, 0.0938],
        [0.8818, 0.7998, 0.0938, 0.0742],
        [0.8442, 0.9551, 0.1299, 0.0859],
        [0.3931, 0.4907, 0.1416, 0.2100],
        [0.0352, 0.0518, 0.0703, 0.1035],
        [0.7617, 0.1060, 0.1348, 0.0947],
        [0.2607, 0.4365, 0.2148, 0.1055],
        [0.9722, 0.6001, 0.0557, 0.1357],
        [0.4053, 0.6138, 0.2148, 0.1006],
        [0.6436, 0.5596, 0.1895, 0.1074],
        [0.7798, 0.6304, 0.1201, 0.1768],
        [0.2969, 0.2041, 0.1797, 0.1133],
        [0.6733, 0.2969, 0.1631, 0.1465],
        [0.3555, 0.9199, 0.0820, 0.1602],
        [0.9546, 0.1553, 0.0869, 0.2266],
        [0.3154, 0.3301, 0.2559, 0.0957],
        [0.0400, 0.6709, 0.0781, 0.0859],
        [0.1562, 0.4932, 0.1230, 0.0547],
        [0.0400, 0.4424, 0.0781, 0.1055]])), ('labels', tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])), ('

# Training Loop

In [4]:
ckpt_cb = ModelCheckpoint(  dirpath="checkpoints",
                            filename="detr-{epoch:02d}-{val_map:.3f}",
                            monitor="val_map",
                            mode="max",
                            save_top_k=3,
                            every_n_epochs=1,)
wandb_logger = WandbLogger(project="wheat-detection", log_model="all")
trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    accelerator="auto",
    callbacks=[ckpt_cb],
    logger=wandb_logger,
    enable_checkpointing= True,
    enable_progress_bar=True,
    enable_model_summary=True
)
    

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

[34m[1mwandb[0m: Currently logged in as: [33mguyperet[0m ([33mguyperet-ben-gurion-university-of-the-negev[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name              | Type                    | Params | Mode 
----------------------------------------------------------------------
0 | backbone          | Sequential              | 23.5 M | train
1 | input_proj        | Conv2d                  | 524 K  | train
2 | position_encoding | PositionEmbeddingSine   | 0      | train
3 | transformer       | Transformer             | 17.4 M | train
4 | query_embed       | Embedding               | 25.6 K | train
5 | class_embed       | Linear                  | 514    | train
6 | bbox_embed        | MLP                     | 132 K  | train
7 | map_metric        | MeanAveragePrecision    | 0      | train
8 | criterion         | HungarianSetCriterion1C | 0      | train
----------------------------------------------------------------------
41.6 M    Trainable params
0         Non-trainable params
41.6 M    Total params
166.221   Total estimated model params size (MB)
313       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/guyperets/.pyenv/versions/wheat-env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/guyperets/.pyenv/versions/wheat-env/lib/python3.12/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/Users/guyperets/.pyenv/versions/wheat-env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
/Users/guyperets/.pyenv/versions/wheat-env/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (43) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower

Training: |          | 0/? [00:00<?, ?it/s]

wandb-core(73632) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Process Process-16:
Process Process-8:
Process Process-15:
Process Process-13:
Process Process-9:
Process Process-12:
Process Process-11:
Process Process-10:
Process Process-2:
Process Process-14:
Process Process-1:
Process Process-3:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/guyperets/.pyenv/versions/3.12.10/lib/python3.12/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/Users/guyperets/.pyenv/versions/3.12.10/lib/python3.12/multiprocessing/util.py", line 363, 