In [1]:
import pandas as pd
import torch

from scripts.training import run_epoch, train_loop
from utils import move_data_to_device, move_dict_to_cpu, plot_image, plot_image_boxes
from PsfSimulator import PsfDataset
from models.subpix_rcnn import SubpixRCNN

import torch.nn as nn
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import misc as misc_nn_ops
from torchvision.transforms._presets import ObjectDetection
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._meta import _COCO_CATEGORIES
from torchvision.models._utils import _ovewrite_value_param, handle_legacy_interface
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection._utils import overwrite_eps
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from torchvision.models.detection.faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead
from torchvision.models.detection.roi_heads import RoIHeads

from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
import torch.nn.functional as F
import torch.optim as optim
import os
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Density 1 = 40 spots per 64x64 image
# Density 0.05 = 2 spots per 64x64 image

In [2]:
# Instatiate the dataset
seed = None

num_spots_min = 2
num_spots_max = 80
sigma_mean= 1.0
sigma_std = 0.1
snr_min = 2
snr_max = 10
snr_std = 0.0
base_noise_min = 20
base_noise_max = 150
use_gauss_noise = False
gauss_noise_std = 0.02
use_perlin_noise = False
perlin_min_max = (0.4, 0.6)
img_w = 64
img_h = 64

train_dataset = PsfDataset(seed, 1000, num_spots_min, num_spots_max, sigma_mean, sigma_std,
                      snr_min, snr_max, snr_std, base_noise_min, base_noise_max, use_gauss_noise,
                      gauss_noise_std, use_perlin_noise, perlin_min_max, img_w, img_h)


valid_dataset = PsfDataset(seed, 50, num_spots_min, num_spots_max, sigma_mean, sigma_std,
                      snr_min, snr_max, snr_std, base_noise_min, base_noise_max, use_gauss_noise,
                      gauss_noise_std, use_perlin_noise, perlin_min_max, img_w, img_h)


# Print the number of samples in the training and validation datasets
print(pd.Series({
    'Training dataset size:': len(train_dataset),
    'Validation dataset size:': len(valid_dataset)}))

data_loader_params = {'batch_size':4,    'collate_fn': lambda batch: tuple(zip(*batch)),}

training_loader = torch.utils.data.DataLoader(train_dataset, **data_loader_params)
validation_loader = torch.utils.data.DataLoader(valid_dataset, **data_loader_params)

Training dataset size:      1000
Validation dataset size:      50
dtype: int64


In [3]:
def generate_checkpoint_path(model_name, project_name, root_dir=os.curdir):
    """Create a fresh timestamped checkpoint directory and return a file path inside it.

    Layout: ``<root_dir>/<project_name>/<YYYY-mm-dd_HH-MM-SS>/<model_name>.pth``.
    Only the directories are created; no checkpoint file is written.

    Args:
        model_name: Base name of the checkpoint file (``.pth`` is appended).
        project_name: Sub-directory of ``root_dir`` that groups runs of one project.
        root_dir: Directory under which the project folder is created. Defaults
            to the current working directory (``os.curdir``), as before.

    Returns:
        The checkpoint file path as a string.
    """
    # Second-resolution timestamp: two calls within the same second resolve to
    # the same directory (exist_ok=True), so a later checkpoint would overwrite.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    checkpoint_dir = os.path.join(root_dir, project_name, timestamp)

    # makedirs also creates the missing project folder, so one call covers both levels.
    os.makedirs(checkpoint_dir, exist_ok=True)

    return os.path.join(checkpoint_dir, f"{model_name}.pth")

In [4]:
from torch.utils.tensorboard import SummaryWriter

In [5]:
# One log directory per run: TensorBoard treats every event file in a directory
# as a single run, so reusing the fixed 'runs/psf_rcnn' directory overlays the
# curves of separate training runs on top of each other.
writer = SummaryWriter(os.path.join('runs', 'psf_rcnn', datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")))

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet-pretrained ResNet-50 + FPN feature extractor. `weights=` is the current
# torchvision API; the deprecated `pretrained=True` flag mapped to these same
# IMAGENET1K_V1 weights, so the resulting backbone is unchanged.
backbone = resnet_fpn_backbone(backbone_name="resnet50", weights=ResNet50_Weights.IMAGENET1K_V1)

# Extra options forwarded to SubpixRCNN.
# NOTE(review): "nms_tresh" looks like a misspelling of "nms_thresh". The key is kept
# as-is because SubpixRCNN's accepted option names are not visible here. TODO confirm
# the intended key; if it is wrong, the NMS threshold may be silently ignored.
kwargs = {"nms_tresh": 0.1, "detections_per_img": None, "score_thresh": 0.7}

# num_classes=2: presumably background + spot (torchvision detection convention) — TODO confirm
model = SubpixRCNN(backbone, num_classes=2, device=device, **kwargs)
model.to(device)
model.name = "testmodel"

learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

checkpoint_path = generate_checkpoint_path("first_long_run", "subpix_rcnn_models")
num_epochs = 200

train_loop(model, training_loader, validation_loader, optimizer, device, num_epochs, checkpoint_path, writer)

# SummaryWriter buffers events in memory; flush so everything logged during
# training reaches disk for TensorBoard even if the kernel stops before close().
writer.flush()



Custom SubpixRoIHeads successfully initialized!


Train: 100%|██████████| 250/250 [00:58<00:00,  4.25it/s, loss=179, avg_loss=0.717, data_time=7367.185s, train_time=30.540s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.09it/s, loss=7.92, avg_loss=0.609, data_time=9.961s, train_time=1.401s]
Epochs:   0%|          | 1/200 [01:00<3:20:46, 60.54s/it]

New best loss: 0.6089555185574752


Train: 100%|██████████| 250/250 [00:57<00:00,  4.36it/s, loss=137, avg_loss=0.547, data_time=7182.089s, train_time=29.495s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.28it/s, loss=7.1, avg_loss=0.546, data_time=9.829s, train_time=1.372s]
Epochs:   1%|          | 2/200 [01:59<3:17:02, 59.71s/it]

New best loss: 0.5459109338430258


Train: 100%|██████████| 250/250 [00:56<00:00,  4.41it/s, loss=130, avg_loss=0.52, data_time=7090.235s, train_time=29.054s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.49it/s, loss=6.22, avg_loss=0.478, data_time=9.713s, train_time=1.335s]
Epochs:   2%|▏         | 3/200 [02:58<3:14:04, 59.11s/it]

New best loss: 0.47816068163284886


Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=126, avg_loss=0.504, data_time=7022.008s, train_time=28.967s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.40it/s, loss=6.6, avg_loss=0.508, data_time=9.919s, train_time=1.337s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=130, avg_loss=0.521, data_time=7037.343s, train_time=29.088s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.35it/s, loss=7.26, avg_loss=0.558, data_time=9.832s, train_time=1.340s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=125, avg_loss=0.5, data_time=7056.103s, train_time=29.071s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.45it/s, loss=6.28, avg_loss=0.483, data_time=9.829s, train_time=1.333s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=126, avg_loss=0.506, data_time=7032.583s, train_time=29.048s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.48it/s, loss=6.88, avg_loss=0.529, data_time=9.723s, train_time=1.333s]
Train: 100%|██████████| 250/250 [00

New best loss: 0.46316010447648853


Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=122, avg_loss=0.486, data_time=7041.988s, train_time=29.191s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.48it/s, loss=5.93, avg_loss=0.456, data_time=9.765s, train_time=1.340s]
Epochs:   6%|▋         | 13/200 [12:39<3:01:45, 58.32s/it]

New best loss: 0.45614994718478274


Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=121, avg_loss=0.484, data_time=7038.250s, train_time=29.182s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.37it/s, loss=6.45, avg_loss=0.496, data_time=9.912s, train_time=1.345s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=120, avg_loss=0.479, data_time=7035.664s, train_time=29.253s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.40it/s, loss=6.43, avg_loss=0.494, data_time=9.762s, train_time=1.348s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.41it/s, loss=123, avg_loss=0.492, data_time=7059.272s, train_time=29.337s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.53it/s, loss=6.69, avg_loss=0.515, data_time=9.549s, train_time=1.323s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=123, avg_loss=0.492, data_time=7043.463s, train_time=29.121s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.40it/s, loss=6.9, avg_loss=0.531, data_time=9.723s, train_time=1.350s]
Train: 100%|██████████| 250/250 [

New best loss: 0.4401691578901731


Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=117, avg_loss=0.47, data_time=7019.637s, train_time=29.055s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.34it/s, loss=6.77, avg_loss=0.521, data_time=9.905s, train_time=1.340s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=118, avg_loss=0.473, data_time=7035.417s, train_time=29.169s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.52it/s, loss=5.46, avg_loss=0.42, data_time=9.714s, train_time=1.331s]
Epochs:  10%|█         | 20/200 [19:26<2:54:26, 58.15s/it]

New best loss: 0.4203055478059329


Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=120, avg_loss=0.478, data_time=7025.459s, train_time=29.100s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.43it/s, loss=5.85, avg_loss=0.45, data_time=9.672s, train_time=1.340s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=121, avg_loss=0.483, data_time=7035.442s, train_time=29.038s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.56it/s, loss=5.75, avg_loss=0.442, data_time=9.522s, train_time=1.336s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=119, avg_loss=0.476, data_time=7025.777s, train_time=29.080s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.46it/s, loss=6.36, avg_loss=0.489, data_time=9.731s, train_time=1.333s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=120, avg_loss=0.481, data_time=7025.735s, train_time=29.080s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.52it/s, loss=7.11, avg_loss=0.547, data_time=9.651s, train_time=1.334s]
Train: 100%|██████████| 250/250 [

New best loss: 0.4062681370056592


Train: 100%|██████████| 250/250 [00:56<00:00,  4.42it/s, loss=115, avg_loss=0.461, data_time=7030.699s, train_time=29.096s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.37it/s, loss=5.67, avg_loss=0.436, data_time=9.804s, train_time=1.343s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=118, avg_loss=0.471, data_time=7039.698s, train_time=29.025s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.25it/s, loss=5.58, avg_loss=0.43, data_time=10.020s, train_time=1.359s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.41it/s, loss=117, avg_loss=0.468, data_time=7069.634s, train_time=29.219s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.49it/s, loss=6.73, avg_loss=0.518, data_time=9.661s, train_time=1.324s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=114, avg_loss=0.457, data_time=7027.707s, train_time=29.098s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.62it/s, loss=5.84, avg_loss=0.449, data_time=9.546s, train_time=1.312s]
Train: 100%|██████████| 250/250 

New best loss: 0.40345531472792995


Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=114, avg_loss=0.456, data_time=7025.141s, train_time=28.978s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.51it/s, loss=5.67, avg_loss=0.436, data_time=9.669s, train_time=1.325s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=115, avg_loss=0.461, data_time=7023.995s, train_time=28.969s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.42it/s, loss=6.67, avg_loss=0.513, data_time=9.720s, train_time=1.331s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.45it/s, loss=114, avg_loss=0.455, data_time=7003.504s, train_time=28.912s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.64it/s, loss=5.69, avg_loss=0.438, data_time=9.616s, train_time=1.313s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=115, avg_loss=0.46, data_time=7024.180s, train_time=28.881s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.52it/s, loss=6.34, avg_loss=0.488, data_time=9.671s, train_time=1.324s]
Train: 100%|██████████| 250/250 [

New best loss: 0.3984619562442486


Train: 100%|██████████| 250/250 [00:56<00:00,  4.45it/s, loss=113, avg_loss=0.452, data_time=7006.533s, train_time=28.894s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.50it/s, loss=6.15, avg_loss=0.473, data_time=9.829s, train_time=1.326s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=117, avg_loss=0.466, data_time=7005.810s, train_time=28.826s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.52it/s, loss=6.05, avg_loss=0.466, data_time=9.649s, train_time=1.320s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=113, avg_loss=0.451, data_time=7013.762s, train_time=28.940s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.41it/s, loss=6.08, avg_loss=0.468, data_time=9.888s, train_time=1.336s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.45it/s, loss=113, avg_loss=0.453, data_time=7003.158s, train_time=28.810s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.47it/s, loss=5.88, avg_loss=0.452, data_time=9.749s, train_time=1.323s]
Train: 100%|██████████| 250/250 

New best loss: 0.39557945728302


Train: 100%|██████████| 250/250 [00:56<00:00,  4.45it/s, loss=111, avg_loss=0.444, data_time=6987.524s, train_time=28.912s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.46it/s, loss=6.38, avg_loss=0.49, data_time=9.851s, train_time=1.326s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=115, avg_loss=0.458, data_time=7017.265s, train_time=29.032s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.54it/s, loss=5.23, avg_loss=0.402, data_time=9.768s, train_time=1.317s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=114, avg_loss=0.456, data_time=7014.661s, train_time=28.928s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.55it/s, loss=5.79, avg_loss=0.445, data_time=9.612s, train_time=1.320s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=110, avg_loss=0.439, data_time=7013.501s, train_time=28.980s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.58it/s, loss=5.78, avg_loss=0.445, data_time=9.608s, train_time=1.320s]
Train: 100%|██████████| 250/250 [

New best loss: 0.39364294937023747


Train: 100%|██████████| 250/250 [00:56<00:00,  4.46it/s, loss=110, avg_loss=0.44, data_time=6985.688s, train_time=28.867s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.56it/s, loss=5.65, avg_loss=0.434, data_time=9.699s, train_time=1.305s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.46it/s, loss=110, avg_loss=0.438, data_time=6987.522s, train_time=28.749s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.57it/s, loss=5.64, avg_loss=0.434, data_time=9.698s, train_time=1.314s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=111, avg_loss=0.444, data_time=7028.034s, train_time=29.074s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.36it/s, loss=6.3, avg_loss=0.484, data_time=9.910s, train_time=1.341s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.43it/s, loss=109, avg_loss=0.435, data_time=7021.499s, train_time=29.140s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.40it/s, loss=5.86, avg_loss=0.451, data_time=9.794s, train_time=1.336s]
Train: 100%|██████████| 250/250 [0

New best loss: 0.37828349608641404


Train: 100%|██████████| 250/250 [00:56<00:00,  4.46it/s, loss=108, avg_loss=0.433, data_time=6976.714s, train_time=28.916s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.55it/s, loss=6.06, avg_loss=0.467, data_time=9.543s, train_time=1.317s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.46it/s, loss=110, avg_loss=0.442, data_time=6983.363s, train_time=28.973s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.46it/s, loss=5.79, avg_loss=0.446, data_time=9.675s, train_time=1.338s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.45it/s, loss=109, avg_loss=0.437, data_time=6995.585s, train_time=28.967s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.66it/s, loss=5.3, avg_loss=0.408, data_time=9.447s, train_time=1.312s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=112, avg_loss=0.448, data_time=6997.319s, train_time=28.947s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.46it/s, loss=6.45, avg_loss=0.496, data_time=9.695s, train_time=1.330s]
Train: 100%|██████████| 250/250 [

New best loss: 0.37544962534537685


Train: 100%|██████████| 250/250 [00:56<00:00,  4.45it/s, loss=108, avg_loss=0.431, data_time=7000.186s, train_time=29.066s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.19it/s, loss=6.07, avg_loss=0.467, data_time=10.064s, train_time=1.365s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.45it/s, loss=109, avg_loss=0.437, data_time=7013.667s, train_time=28.979s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.65it/s, loss=5.64, avg_loss=0.434, data_time=9.527s, train_time=1.311s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=109, avg_loss=0.435, data_time=7021.297s, train_time=29.022s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.65it/s, loss=5.16, avg_loss=0.397, data_time=9.524s, train_time=1.317s]
Train: 100%|██████████| 250/250 [00:56<00:00,  4.44it/s, loss=109, avg_loss=0.438, data_time=7000.848s, train_time=29.006s]
Eval: 100%|██████████| 13/13 [00:01<00:00,  8.48it/s, loss=5.72, avg_loss=0.44, data_time=9.780s, train_time=1.330s]
Train: 100%|██████████| 250/250 

In [None]:
import matplotlib.pyplot as plt
from perlin_numpy import generate_perlin_noise_2d
