In [21]:
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import torchvision.transforms as transforms
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm
import torch.optim as optim
import random

In [15]:
# Active dataset, given as a path relative to the "datasets/" root.
ds = "dataset1/Military objects in military environments/Dataset"
# ds = f"military drones images"

# Directory that receives the model checkpoints and the training log.
models = "models"

# One directory per split; note the validation split lives in a folder named "valid".
train_dataset, test_dataset, val_dataset = (
    f"datasets/{ds}/{split}" for split in ("train", "test", "valid")
)

In [8]:
# Debugging aid: requests synchronous CUDA kernel launches so a failing kernel is
# reported at the call that caused it. This slows GPU training; drop it once debugging is done.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

In [9]:
# Load the pretrained DETR model (ResNet-50 backbone) from torch.hub.
# The first call downloads the repository and the weights into the local torch hub cache.
model2 = torch.hub.load(
    repo_or_dir='facebookresearch/detr:main',
    model='detr_resnet50',
    pretrained=True,
)


Downloading: "https://github.com/facebookresearch/detr/zipball/main" to /Users/dania/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth" to /Users/dania/.cache/torch/hub/checkpoints/detr-r50-e632da11.pth
100.0%


In [None]:
# Optional step (disabled): freeze the ResNet backbone and the Transformer so their
# weights are not updated during fine-tuning. Uncomment the lines below to enable it.
#
# #Freeze the ResNet backbone
# for param in model2.backbone.parameters():
#     param.requires_grad = False

# # Freeze the Transformer
# for param in model2.transformer.parameters():
#     param.requires_grad = False

# print("ResNet and Transformer layers frozen.")

""" this block is optional, uncomment for faster training but worse results"""

In [11]:
# Inspect the two prediction heads: the classifier and the bounding-box regressor.
print("Classification head architecture:", model2.class_embed, sep="\n")
print("\nBox prediction head architecture:", model2.bbox_embed, sep="\n")
"""
look at out_features, the value represents how many classes the model is supposed to predict

"""

Classification head architecture:
Linear(in_features=256, out_features=92, bias=True)

Box prediction head architecture:
MLP(
  (layers): ModuleList(
    (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=4, bias=True)
  )
)


'\nlook at out_features, the value represents how many classes the model is supposed to predict\n\n'

In [12]:
class DroneDataset(torch.utils.data.Dataset):
    """Images paired with YOLO-format label files, yielding DETR-style targets.

    Each image ``<name>.<ext>`` in ``image_dir`` is paired with ``<name>.txt`` in
    ``label_dir``. Every non-empty label line is read as
    ``class_id center_x center_y width height``.

    Boxes are returned exactly as stored, i.e. in ``(cx, cy, w, h)`` order. That is
    the format DETRLoss converts from (via ``box_cxcywh_to_xyxy``) and the format
    DETR predicts, so no conversion to corner coordinates is performed here.
    NOTE(review): the coordinates are assumed to be normalized to [0, 1] as is
    usual for YOLO labels — confirm against the dataset.

    Args:
        image_dir: directory containing ``.jpg`` / ``.jpeg`` / ``.png`` images.
        label_dir: directory containing the matching ``.txt`` label files.
        classes: class-id -> class-name mapping (stored, not used by this class).
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_dir, label_dir, classes, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.classes = classes
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible across runs and machines.
        self.image_filenames = sorted(
            f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.jpeg', '.png'))
        )

    def __len__(self):
        """Number of images found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` is a dict with:
            - ``"boxes"``: float32 tensor of shape ``[num_objects, 4]`` in
              ``(cx, cy, w, h)`` order (shape ``[0, 4]`` when there are no objects).
            - ``"labels"``: int64 tensor of shape ``[num_objects]``.

        An image without a label file is returned with empty targets.
        """
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # Label files share the image's base name and use a .txt extension.
        label_path = os.path.join(self.label_dir, os.path.splitext(img_name)[0] + '.txt')

        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        boxes = []
        labels = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # Tolerate blank lines (e.g. a trailing newline at end of file).
                        continue
                    class_id, center_x, center_y, width, height = map(float, fields)
                    boxes.append([center_x, center_y, width, height])
                    labels.append(int(class_id))

        target = {
            # reshape keeps a well-formed [0, 4] tensor when the image has no objects,
            # so downstream torch.cat / cdist calls do not fail on a 1-D empty tensor.
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
        }

        return image, target

In [13]:
# Class-id -> class-name lookup. The position of each name is its id
# (presumably matching the class_id column of the label files — verify against the dataset).
classes = dict(enumerate(("Tank", "drone", "people", "soldier")))
#tank drone people soldier

In [16]:
# Image and label directories for each split
train_image_dir = train_dataset + "/images"
train_label_dir = train_dataset + "/labels"
val_image_dir = val_dataset + "/images"
val_label_dir = val_dataset + "/labels"
test_image_dir = test_dataset + "/images"
test_label_dir = test_dataset + "/labels"

# Input pipeline: resize to a fixed size, convert to a [0, 1] tensor, then apply
# ImageNet channel normalization. The pretrained DETR weights were trained on
# inputs normalized with these statistics, so feeding un-normalized images
# degrades the pretrained features.
transform = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def collate_batch(batch):
    """Turn a list of (image, target) samples into (images, targets) tuples.

    Images are kept as separate tensors (not stacked) because each target holds a
    different number of boxes. A named function is used instead of a lambda so
    the loaders stay picklable if num_workers is raised above 0.
    """
    return tuple(zip(*batch))


# Instantiate the DroneDataset for each split
train_drone_dataset = DroneDataset(train_image_dir, train_label_dir, classes, transform=transform)
val_drone_dataset = DroneDataset(val_image_dir, val_label_dir, classes, transform=transform)
test_drone_dataset = DroneDataset(test_image_dir, test_label_dir, classes, transform=transform)


# Instantiate DataLoader objects for each dataset
batch_size = 1  # You can adjust this based on your GPU memory
# num_workers=0 loads data in the main process; only the training split is shuffled.
train_dataloader = DataLoader(train_drone_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0)
val_dataloader = DataLoader(val_drone_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch, num_workers=0)
test_dataloader = DataLoader(test_drone_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch, num_workers=0)

print("DataLoaders for train, validation, and test datasets instantiated with num_workers.")

DataLoaders for train, validation, and test datasets instantiated with num_workers.


In [17]:
# Number of object classes in the dataset (4: Tank, drone, people, soldier)
num_classes = len(classes)

# Replace the classification head so it matches this dataset.
# DETR reserves the LAST logit for the "no object" class, so the head needs
# num_classes + 1 outputs (the pretrained head has 92 for the same reason, and
# the criterion is built with out_features - 1 object classes). With only
# num_classes outputs, the highest class id (3, "soldier") would share its logit
# with "no object" and could never be learned.
# The input width stays equal to the transformer decoder's output width.
# NOTE: checkpoints saved with the previous 4-logit head cannot be loaded into this head.
model2.class_embed = torch.nn.Linear(model2.class_embed.in_features, num_classes + 1)

# Move the modified model to the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model2.to(device)

# Reports the number of object classes; the head itself has one more logit ("no object").
print(f"Model's classification head modified to output {num_classes} classes.")

Model's classification head modified to output 4 classes.


In [18]:
# Place the model on the GPU when one is available, otherwise on the CPU.
# As the last expression of the cell, this also displays the module tree.
model2.to("cuda" if torch.cuda.is_available() else "cpu")

DETR(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, ou

In [19]:
class DETRLoss(nn.Module):
    """
    DETR Loss Function with Hungarian Matching

    This loss computes the optimal bipartite matching between predicted and ground truth objects,
    and then computes classification and bounding box regression losses.

    Conventions used throughout:
    - class labels are integers in [0, num_classes); index ``num_classes`` (the last
      logit) is the "no object" class assigned to every unmatched query;
    - predicted and target boxes are in (cx, cy, w, h) format;
    - images without any ground-truth object are supported (all queries are then
      trained towards "no object" and the box losses are zero).
    """

    def __init__(self, num_classes, matcher_cost_class=1, matcher_cost_bbox=5,
                 matcher_cost_giou=2, loss_ce=2, loss_bbox=2.5, loss_giou=2,
                 eos_coef=0.1):
        """
        Parameters:
        - num_classes: number of object categories
        - matcher_cost_class: relative weight of classification error in matching cost
        - matcher_cost_bbox: relative weight of L1 error of bounding box coordinates in matching
        - matcher_cost_giou: relative weight of giou loss of bounding box in matching
        - loss_ce: relative weight of classification loss
        - loss_bbox: relative weight of L1 bounding box loss
        - loss_giou: relative weight of giou bounding box loss
        - eos_coef: relative classification weight applied to the no-object category
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher_cost_class = matcher_cost_class
        self.matcher_cost_bbox = matcher_cost_bbox
        self.matcher_cost_giou = matcher_cost_giou
        self.loss_ce = loss_ce
        self.loss_bbox = loss_bbox
        self.loss_giou = loss_giou
        self.eos_coef = eos_coef

        # Per-class weights for the cross-entropy: 1 for every object class and
        # eos_coef for the trailing "no object" class, which would otherwise
        # dominate because most queries are unmatched.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef  # Background class
        self.register_buffer('empty_weight', empty_weight)

    @torch.no_grad()
    def hungarian_matching(self, outputs, targets):
        """Find the minimum-cost one-to-one assignment of queries to ground-truth objects.

        The matching is not differentiable, so it runs without gradient tracking.

        Returns:
            A list with one ``(pred_indices, target_indices)`` pair of int64 CPU
            tensors per image. Both tensors are empty for an image without objects.
        """
        batch_size, num_queries = outputs["pred_logits"].shape[:2]

        # Flatten to compute the cost matrices for the whole batch at once.
        # log_softmax gives the same values as softmax().log() but cannot
        # underflow to log(0) = -inf, which would make the assignment infeasible.
        out_logprob = outputs["pred_logits"].flatten(0, 1).log_softmax(-1)  # [batch_size * num_queries, num_classes + 1]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Concatenate all target labels and boxes. reshape(-1, 4) keeps an
        # empty box tensor two-dimensional so cat/cdist accept it.
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"].reshape(-1, 4) for v in targets])
        sizes = [v["boxes"].reshape(-1, 4).shape[0] for v in targets]

        # Classification cost: negative log-likelihood of the target class
        # (large when the query assigns low probability to that class).
        cost_class = -out_logprob[:, tgt_ids]

        # Box cost 1: L1 distance between predicted and target box coordinates.
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Box cost 2: negative generalized IoU (GIoU lies in [-1, 1]; higher is better).
        cost_giou = -self.generalized_box_iou(
            self.box_cxcywh_to_xyxy(out_bbox),
            self.box_cxcywh_to_xyxy(tgt_bbox)
        )

        # Final cost matrix: weighted sum of the three costs.
        C = (self.matcher_cost_bbox * cost_bbox +
             self.matcher_cost_class * cost_class +
             self.matcher_cost_giou * cost_giou)
        # The last dimension is given explicitly: view(..., -1) is ambiguous
        # (and raises) when the batch contains no target at all.
        C = C.view(batch_size, num_queries, tgt_bbox.shape[0]).cpu()

        indices = []
        # Splitting along the target axis yields, for image i, the columns that
        # belong to its own targets; row block i holds its own queries.
        for i, c in enumerate(C.split(sizes, -1)):
            if sizes[i] == 0:
                # Nothing to match for this image.
                pred_indices = target_indices = torch.empty(0, dtype=torch.int64)
            else:
                rows, cols = linear_sum_assignment(c[i].numpy())
                pred_indices = torch.as_tensor(rows, dtype=torch.int64)
                target_indices = torch.as_tensor(cols, dtype=torch.int64)
            indices.append((pred_indices, target_indices))

        return indices

    def loss_labels(self, outputs, targets, indices):
        """Classification loss (Cross Entropy)

        Matched queries are trained towards their target class; every other
        query is trained towards the "no object" class (index num_classes).
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Start with "no object" everywhere, then fill in the matched queries.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                   dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o.to(src_logits.device)

        # cross_entropy expects the class dimension second: [batch, classes, queries].
        # The weight buffer is moved explicitly because the criterion itself may
        # not have been moved to the model's device.
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight.to(src_logits.device))
        return loss_ce

    def loss_boxes(self, outputs, targets, indices):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss

        Both losses are averaged over the number of matched target boxes and are
        zero when the batch contains no target box.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat(
            [t['boxes'].reshape(-1, 4)[i] for t, (_, i) in zip(targets, indices)], dim=0
        ).to(src_boxes.device)

        num_boxes = len(target_boxes)
        if num_boxes == 0:
            return (torch.tensor(0.0, device=src_boxes.device),
                    torch.tensor(0.0, device=src_boxes.device))

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none').sum() / num_boxes

        # Only the diagonal is needed: prediction k is paired with target k.
        giou = torch.diag(self.generalized_box_iou(
            self.box_cxcywh_to_xyxy(src_boxes),
            self.box_cxcywh_to_xyxy(target_boxes)))
        loss_giou = (1 - giou).sum() / num_boxes

        return loss_bbox, loss_giou

    def _get_src_permutation_idx(self, indices):
        """Build (batch_index, query_index) tensors selecting every matched prediction"""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def box_cxcywh_to_xyxy(self, x):
        """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) format"""
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=-1)

    def generalized_box_iou(self, boxes1, boxes2):
        """
        Generalized IoU from https://giou.stanford.edu/
        The boxes should be in [x0, y0, x1, y1] format

        GIoU = IoU - (area(C) - area(union)) / area(C), where C is the smallest
        box enclosing both inputs. Returns an [N, M] matrix of pairwise values
        for N boxes in boxes1 and M boxes in boxes2.
        """
        # Ensure boxes are valid (non-negative width and height)
        assert (boxes1[:, 2:] >= boxes1[:, :2]).all(), "boxes1 contains a box with negative width or height"
        assert (boxes2[:, 2:] >= boxes2[:, :2]).all(), "boxes2 contains a box with negative width or height"

        # Guards the divisions below against zero-area (degenerate) boxes,
        # which would otherwise produce NaN.
        eps = 1e-7

        # Compute intersection
        lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
        rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

        wh = (rb - lt).clamp(min=0)  # [N,M,2]
        inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

        # Compute union
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        union = area1[:, None] + area2 - inter

        # Compute IoU
        iou = inter / union.clamp(min=eps)

        # Compute the area of the smallest enclosing box
        lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
        rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

        whi = (rbi - lti).clamp(min=0)  # [N,M,2]
        areai = whi[:, :, 0] * whi[:, :, 1]

        return iou - (areai - union) / areai.clamp(min=eps)

    def forward(self, outputs, targets):
        """
        This performs the loss computation.

        Args:
            outputs: dict of tensors with keys:
                - pred_logits: Tensor of dim [batch_size, num_queries, num_classes + 1]
                - pred_boxes: Tensor of dim [batch_size, num_queries, 4] in cxcywh format
            targets: list of dicts, such that len(targets) == batch_size.
                Each dict should contain:
                - labels: Tensor of dim [num_objects] containing the class labels
                - boxes: Tensor of dim [num_objects, 4] containing the boxes in cx,cy,w,h format

        Returns:
            dict: the weighted losses 'loss_ce', 'loss_bbox', 'loss_giou' and
            their sum 'loss_total'
        """
        # Retrieve the matching between the outputs of the model and the targets
        indices = self.hungarian_matching(outputs, targets)

        # Compute all the losses
        loss_ce = self.loss_labels(outputs, targets, indices)
        loss_bbox, loss_giou = self.loss_boxes(outputs, targets, indices)

        # Combine losses
        losses = {
            'loss_ce': loss_ce * self.loss_ce,
            'loss_bbox': loss_bbox * self.loss_bbox,
            'loss_giou': loss_giou * self.loss_giou,
        }

        # Total loss
        losses['loss_total'] = sum(losses.values())

        return losses
# Build the training criterion. The last logit of the classification head is the
# "no object" class, so the number of object classes is out_features - 1.
# Matcher costs decide which prediction is paired with which ground-truth box;
# loss weights scale the individual terms of the final loss.
criterion = DETRLoss(
    model2.class_embed.out_features - 1,
    matcher_cost_class=1, matcher_cost_bbox=5, matcher_cost_giou=2,
    loss_ce=1, loss_bbox=5, loss_giou=2,
    eos_coef=0.1,
)

In [22]:
num_epochs = 300
# Move the model to the device (GPU if available)
model2.to(device)

# Define a directory to save checkpoints
checkpoint_dir = models
os.makedirs(checkpoint_dir, exist_ok=True) # just make the dir if not there


def _checkpoint_epoch(filename):
    """Return the epoch encoded in a '<name>_<epoch>.pth' file name, or None if there is none."""
    stem, ext = os.path.splitext(filename)
    suffix = stem.rsplit('_', 1)[-1]
    if ext != '.pth' or not suffix.isdigit():
        return None
    return int(suffix)


start_epoch = 0

# Resume from the checkpoint with the highest epoch number. The epoch is taken
# from the file name rather than from the file's ctime, because ctime changes
# when files are copied or their metadata is touched. '.pth' files that do not
# end in '_<number>' are ignored instead of crashing the epoch parsing.
epoch_by_name = {}
for candidate in os.listdir(checkpoint_dir):
    candidate_epoch = _checkpoint_epoch(candidate)
    if candidate_epoch is not None:
        epoch_by_name[candidate] = candidate_epoch

if epoch_by_name:
    print("Found existing checkpoints. Resuming training.")
    latest_name = max(epoch_by_name, key=epoch_by_name.get)
    latest_checkpoint = os.path.join(checkpoint_dir, latest_name)
    print(f"Loading checkpoint from: {latest_checkpoint}")
    # Load onto the CPU first; load_state_dict copies the values into the
    # model's parameters wherever they currently live.
    model2.load_state_dict(torch.load(latest_checkpoint, map_location=torch.device('cpu')))
    # Checkpoints are named after the 1-based epoch that produced them, so the
    # number is also the 0-based index of the next epoch to run.
    start_epoch = epoch_by_name[latest_name]
    print(f"Resuming from epoch {start_epoch + 1}")
else:
    print("No existing checkpoints found. Starting training from scratch.")

# NOTE: only model weights are checkpointed, so the optimizer state (AdamW
# moment estimates) restarts from zero on every resume.
optimizer = optim.AdamW(model2.parameters(), lr=1e-5)


No existing checkpoints found. Starting training from scratch.


In [23]:
# Release unused memory cached by PyTorch's CUDA allocator before training starts.
torch.cuda.empty_cache()

# Validation helper, called by the training loop after each epoch

def validate_model(model, dataloader, loss_fn, device, checkpoint_dir):
    """Evaluate the most recently created checkpoint on ``dataloader``.

    The newest ``.pth`` file in ``checkpoint_dir`` (by creation/change time) is
    loaded into ``model`` in place, the model is switched to evaluation mode, and
    the mean of ``loss_fn(outputs, targets)['loss_total']`` over all batches is
    computed without gradient tracking.

    Args:
        model: module whose weights are overwritten by the checkpoint.
        dataloader: iterable of ``(images, targets)`` batches that supports ``len()``.
        loss_fn: callable returning a dict with a ``'loss_total'`` tensor.
        device: device the batches are moved to and the checkpoint is mapped onto.
        checkpoint_dir: directory searched for ``.pth`` checkpoints.

    Returns:
        The average loss as a float, or ``None`` when there is no checkpoint or
        the dataloader yields no batches.

    Note:
        The model is left in evaluation mode; callers that continue training
        must call ``model.train()`` again.
    """
    # Find the latest checkpoint
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    if not checkpoints:
        print("No checkpoints found in the specified directory.")
        return None

    latest_checkpoint = max((os.path.join(checkpoint_dir, f) for f in checkpoints), key=os.path.getctime)
    print(f"Loading model from: {latest_checkpoint}")
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(latest_checkpoint, map_location=device))

    num_batches = len(dataloader)
    if num_batches == 0:
        # Avoid a ZeroDivisionError when the validation split is empty.
        print("Validation dataloader is empty; skipping validation.")
        return None

    model.eval() # Set the model to evaluation mode
    running_loss = 0.0
    with torch.no_grad(): # Disable gradient calculation
        for images, targets in tqdm(dataloader, desc="Validation"):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            outputs = model(images)
            loss = loss_fn(outputs, targets)['loss_total']

            running_loss += loss.item()

    return running_loss / num_batches

# Define the checkpoint directory
checkpoint_dir = models # Use the same directory as in training

# Standalone validation run (disabled). The training loop already calls
# validate_model after every epoch; uncomment to validate the latest checkpoint on its own.
# # Ensure the model is on the correct device before loading the state_dict
# model2.to(device)

# # Run validation
# # Make sure to use the correct loss function (DETRLoss)
# val_loss = validate_model(model2, val_dataloader, criterion, device, checkpoint_dir)
# if val_loss is not None:
#     print(f"Validation Loss: {val_loss:.4f}")

In [24]:
import os
import logging
from datetime import datetime

# Set up logging configuration: every record goes to <checkpoint_dir>/log.txt
# and is echoed to the console.
log_file = os.path.join(checkpoint_dir, 'log.txt')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()  # This will still print to console
    ]
)

# Log training start information
logging.info("="*50)
logging.info("TRAINING SESSION STARTED")
logging.info(f"Start epoch: {start_epoch}")
logging.info(f"Total epochs: {num_epochs}")
logging.info(f"Device: {device}")
logging.info("="*50)

# 1-based number of the last epoch that was started. It is initialised here so the
# closing log line also works when the loop body never runs
# (start_epoch >= num_epochs), where `epoch` would be undefined.
final_epoch = start_epoch

for epoch in range(start_epoch, num_epochs): # start from the latest epoch and end with the end epoch(300)
    final_epoch = epoch + 1
    try: # any failure inside an epoch is logged and stops training (see except below)
        logging.info(f"Starting Epoch {epoch+1}/{num_epochs}")

        model2.train() # training mode; validate_model leaves the model in eval mode
        running_loss = 0.0 # cumulative loss over the epoch
        batch_count = 0

        # Wrap the DataLoader with tqdm for a progress bar
        for images, targets in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            # Move images and targets to the device (GPU if available)
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model2(images)

            losses = criterion(outputs, targets) # Calculate total loss using detrloss
            loss = losses['loss_total']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batch_count += 1

        # Calculate average training loss for this epoch
        # (batch_count equals len(train_dataloader) once the loop has finished)
        avg_train_loss = running_loss / batch_count

        # Log training results
        logging.info(f"Epoch [{epoch+1}/{num_epochs}] - Training Loss: {avg_train_loss:.4f}")
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}")

        # Save the model checkpoint after each successful epoch
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
        torch.save(model2.state_dict(), checkpoint_path)
        logging.info(f"Model checkpoint saved to {checkpoint_path}")
        print(f"Model saved to {checkpoint_path}")

        # Run validation and log results
        val_loss = validate_model(model2, val_dataloader, criterion, device, checkpoint_dir)
        if val_loss is not None:
            logging.info(f"Epoch [{epoch+1}/{num_epochs}] - Validation Loss: {val_loss:.4f}")
            print(f"Validation Loss: {val_loss:.4f}")
        else:
            logging.warning(f"Epoch [{epoch+1}/{num_epochs}] - Validation returned None")

        # Log epoch completion
        logging.info(f"Epoch {epoch+1} completed successfully")
        logging.info("-" * 30)

    except Exception as e:
        error_msg = f"An error occurred during epoch {epoch+1}: {e}"
        # logging.exception records the message at ERROR level together with
        # the traceback, which the plain message alone would lose.
        logging.exception(error_msg)
        print(error_msg)

        logging.info("Attempting to load the last saved checkpoint and stopping training.")
        print("Attempting to load the last saved checkpoint and stopping training.")

        # Find the last saved checkpoint before the error occurred
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
        if checkpoints:
            last_checkpoint = max([os.path.join(checkpoint_dir, f) for f in checkpoints], key=os.path.getctime)
            logging.info(f"Loading checkpoint from: {last_checkpoint}")
            print(f"Loading checkpoint from: {last_checkpoint}")
            try:
                # map_location keeps the restore working on a machine without
                # the device the checkpoint was saved from.
                model2.load_state_dict(torch.load(last_checkpoint, map_location=device))
                logging.info("Successfully loaded last checkpoint.")
                print("Successfully loaded last checkpoint.")
            except Exception as load_error:
                error_msg = f"Error loading checkpoint: {load_error}"
                logging.error(error_msg)
                print(error_msg)
        else:
            logging.warning("No checkpoints found to load.")
            print("No checkpoints found to load.")

        logging.info("Training stopped due to error.")
        break # Stop the training loop after encountering an error

# Log training completion
logging.info("="*50)
logging.info("TRAINING SESSION ENDED")
logging.info(f"Final epoch reached: {final_epoch}")
logging.info("="*50)

# After training (or interruption), move the model back to CPU if needed for inference or saving
# model2.to('cpu')
# check epoch 27 and 74

2025-10-30 21:24:16,320 - INFO - TRAINING SESSION STARTED
2025-10-30 21:24:16,325 - INFO - Start epoch: 0
2025-10-30 21:24:16,327 - INFO - Total epochs: 300
2025-10-30 21:24:16,329 - INFO - Device: cpu
2025-10-30 21:24:16,334 - INFO - Starting Epoch 1/300
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch 1/300:   0%|          | 4/5591 [00:11<4:21:13,  2.81s/it]


KeyboardInterrupt: 