In [1]:
import pandas as pd
import torch

from scripts.training import run_epoch, train_loop
from utils import move_data_to_device, move_dict_to_cpu, plot_image, plot_image_boxes
from PsfSimulator import PsfDataset
from models.subpix_rcnn import SubpixRCNN

import torch.nn as nn
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import misc as misc_nn_ops
from torchvision.transforms._presets import ObjectDetection
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._meta import _COCO_CATEGORIES
from torchvision.models._utils import _ovewrite_value_param, handle_legacy_interface
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection._utils import overwrite_eps
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from torchvision.models.detection.faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead
from torchvision.models.detection.roi_heads import RoIHeads

from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
import torch.nn.functional as F
import torch.optim as optim
import os
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Spot-density scale for 64x64 images: density = (number of spots) / 40, i.e.
#   density 1    = 40 spots per 64x64 image
#   density 0.05 =  2 spots per 64x64 image
# On this scale the dataset cell's num_spots_min=2 / num_spots_max=80 span densities 0.05 to 2.0.

In [2]:
# Instantiate the training and validation datasets.
# seed=None draws a fresh random dataset on every run; set an int for a reproducible run.
seed = None
# Give the validation split its own seed. Assuming PsfDataset derives its samples from `seed`,
# reusing one seed for both splits would presumably regenerate the training images as the
# validation set once `seed` is fixed. With seed=None this line has no effect.
valid_seed = None if seed is None else seed + 1

num_spots_min = 2
num_spots_max = 80
sigma_mean = 1.0
sigma_std = 0.1
snr_min = 2
snr_max = 20
snr_std = 0.0
base_noise_min = 50
base_noise_max = 6000
use_gauss_noise = False
gauss_noise_std = 0.02
use_perlin_noise = False
perlin_min_max = (0.4, 0.6)
img_w = 64
img_h = 64

train_size = 1000  # number of samples in the training split
valid_size = 50    # number of samples in the validation split

train_dataset = PsfDataset(seed, train_size, num_spots_min, num_spots_max, sigma_mean, sigma_std,
                           snr_min, snr_max, snr_std, base_noise_min, base_noise_max, use_gauss_noise,
                           gauss_noise_std, use_perlin_noise, perlin_min_max, img_w, img_h)

valid_dataset = PsfDataset(valid_seed, valid_size, num_spots_min, num_spots_max, sigma_mean, sigma_std,
                           snr_min, snr_max, snr_std, base_noise_min, base_noise_max, use_gauss_noise,
                           gauss_noise_std, use_perlin_noise, perlin_min_max, img_w, img_h)


# Print the number of samples in the training and validation datasets
print(pd.Series({
    'Training dataset size:': len(train_dataset),
    'Validation dataset size:': len(valid_dataset)}))


def collate_detection_batch(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    E.g. [(img0, tgt0), (img1, tgt1)] -> ((img0, img1), (tgt0, tgt1)); samples are
    presumably (image, target) pairs. Nothing is stacked, so targets with different
    numbers of spots can share a batch. Unlike a lambda, a named function can be
    pickled, so the loaders keep working with num_workers > 0 under 'spawn'.
    """
    return tuple(zip(*batch))


data_loader_params = {'batch_size': 8, 'collate_fn': collate_detection_batch}

# Shuffle the training split so batch composition changes between epochs;
# validation order is irrelevant to the averaged metrics, so keep it fixed.
training_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **data_loader_params)
validation_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **data_loader_params)

Training dataset size:      1000
Validation dataset size:      50
dtype: int64


In [3]:
def generate_checkpoint_path(model_name, project_name, base_dir=os.curdir):
    """Create a fresh timestamped run directory and return a checkpoint path inside it.

    Layout: ``<base_dir>/<project_name>/<YYYY-mm-dd_HH-MM-SS>/<model_name>.pth``.
    Only the directories are created; the ``.pth`` file itself is not written here.

    Args:
        model_name: File stem of the checkpoint (``.pth`` is appended).
        project_name: Sub-directory of ``base_dir`` grouping all runs of a project.
        base_dir: Root under which the project directory lives. Defaults to the
            current working directory (the previous, hard-coded behavior).

    Returns:
        The checkpoint file path as a string.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    checkpoint_dir = os.path.join(base_dir, project_name, timestamp)
    # makedirs also creates the project directory. exist_ok because two calls within
    # the same second resolve to the same timestamped directory.
    os.makedirs(checkpoint_dir, exist_ok=True)
    return os.path.join(checkpoint_dir, f"{model_name}.pth")

In [4]:
from torch.utils.tensorboard import SummaryWriter

In [5]:
# Log each run to its own timestamped sub-directory. A fixed log_dir would collect the event
# files of every re-run into one TensorBoard run with overlapping step numbers.
# `tensorboard --logdir runs/psf_rcnn2` still shows all runs side by side.
tensorboard_run_dir = os.path.join('runs', 'psf_rcnn2', datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
writer = SummaryWriter(tensorboard_run_dir)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 FPN backbone with ImageNet weights. The legacy `pretrained=True` flag is deprecated
# in torchvision; IMAGENET1K_V1 is the checkpoint that flag selected, so the weights are unchanged.
backbone = resnet_fpn_backbone(backbone_name="resnet50", weights=ResNet50_Weights.IMAGENET1K_V1)

# Detection post-processing options forwarded to SubpixRCNN.
# "nms_thresh" was previously misspelled "nms_tresh". The sibling keys match torchvision's
# RoIHeads parameter names (score_thresh, nms_thresh, detections_per_img), so a misspelled
# key would presumably be ignored and NMS would fall back to its default.
# NOTE(review): confirm the key name against SubpixRCNN's signature.
# detections_per_img=None presumably means "no per-image cap" — TODO confirm in SubpixRoIHeads.
kwargs = {"nms_thresh": 0.1, "detections_per_img": None, "score_thresh": 0.8}

# num_classes=2 presumably follows the torchvision convention of background + one "spot" class.
model = SubpixRCNN(backbone, num_classes=2, device=device, **kwargs)
model.to(device)
model.name = "testmodel"
optimizer = optim.Adam(model.parameters(), lr=0.0001)

checkpoint_path = generate_checkpoint_path("second_long_run", "subpix_rcnn_models")
num_epochs = 200

train_loop(model, training_loader, validation_loader, optimizer, device, num_epochs, checkpoint_path, writer)



Custom SubpixRoIHeads successfully initialized!


  signal = np.nanmax(array[y1:y2, x1:x2])
Train: 100%|██████████| 125/125 [01:00<00:00,  2.07it/s, loss=101, avg_loss=0.807, data_time=3784.257s, train_time=28.943s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.20it/s, loss=2.98, avg_loss=0.426, data_time=5.762s, train_time=1.359s]
Epochs:   0%|          | 1/200 [01:02<3:26:31, 62.27s/it]

New best loss: 0.4261131158896855


Train: 100%|██████████| 125/125 [00:59<00:00,  2.09it/s, loss=55.5, avg_loss=0.444, data_time=3729.976s, train_time=28.092s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.12it/s, loss=2.98, avg_loss=0.426, data_time=5.932s, train_time=1.342s]
Epochs:   1%|          | 2/200 [02:04<3:24:35, 62.00s/it]

New best loss: 0.4255546161106655


Train: 100%|██████████| 125/125 [01:00<00:00,  2.08it/s, loss=50.2, avg_loss=0.402, data_time=3726.797s, train_time=28.102s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.21it/s, loss=2.47, avg_loss=0.353, data_time=5.787s, train_time=1.349s]
Epochs:   2%|▏         | 3/200 [03:05<3:23:22, 61.94s/it]

New best loss: 0.3530870016132082


Train: 100%|██████████| 125/125 [01:00<00:00,  2.08it/s, loss=47.3, avg_loss=0.378, data_time=3743.738s, train_time=28.197s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.11it/s, loss=2.95, avg_loss=0.422, data_time=5.997s, train_time=1.380s]
Train: 100%|██████████| 125/125 [01:00<00:00,  2.08it/s, loss=44.8, avg_loss=0.358, data_time=3718.722s, train_time=28.152s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.18it/s, loss=1.99, avg_loss=0.284, data_time=5.953s, train_time=1.354s]
Epochs:   2%|▎         | 5/200 [05:09<3:21:11, 61.91s/it]

New best loss: 0.28407070466450285


Train: 100%|██████████| 125/125 [00:59<00:00,  2.09it/s, loss=43.6, avg_loss=0.349, data_time=3718.258s, train_time=28.100s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.12it/s, loss=2.33, avg_loss=0.333, data_time=5.966s, train_time=1.356s]
Train: 100%|██████████| 125/125 [00:59<00:00,  2.09it/s, loss=44.6, avg_loss=0.357, data_time=3723.613s, train_time=28.134s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.15it/s, loss=2.45, avg_loss=0.35, data_time=5.905s, train_time=1.341s]
Train: 100%|██████████| 125/125 [00:59<00:00,  2.09it/s, loss=43.8, avg_loss=0.35, data_time=3724.188s, train_time=28.065s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.10it/s, loss=2.86, avg_loss=0.409, data_time=6.055s, train_time=1.351s]
Train: 100%|██████████| 125/125 [01:00<00:00,  2.07it/s, loss=43.8, avg_loss=0.35, data_time=3754.183s, train_time=28.309s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.18it/s, loss=2.21, avg_loss=0.316, data_time=5.894s, train_time=1.339s]
Train: 100%|██████████| 125/125 [00:59<

New best loss: 0.2462271558386939


Train: 100%|██████████| 125/125 [00:59<00:00,  2.09it/s, loss=39.7, avg_loss=0.317, data_time=3731.264s, train_time=28.091s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.23it/s, loss=2.5, avg_loss=0.357, data_time=5.879s, train_time=1.301s]
Train: 100%|██████████| 125/125 [00:57<00:00,  2.16it/s, loss=38.5, avg_loss=0.308, data_time=3587.169s, train_time=27.135s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.36it/s, loss=2.23, avg_loss=0.318, data_time=5.712s, train_time=1.290s]
Train: 100%|██████████| 125/125 [00:57<00:00,  2.16it/s, loss=40.7, avg_loss=0.325, data_time=3601.810s, train_time=27.134s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.29it/s, loss=2.01, avg_loss=0.288, data_time=5.745s, train_time=1.317s]
Train: 100%|██████████| 125/125 [00:58<00:00,  2.15it/s, loss=38.1, avg_loss=0.305, data_time=3604.381s, train_time=27.297s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.36it/s, loss=2.25, avg_loss=0.321, data_time=5.645s, train_time=1.288s]
Train: 100%|██████████| 125/125 [00:5

New best loss: 0.20575207258973802


Train: 100%|██████████| 125/125 [00:57<00:00,  2.16it/s, loss=35.8, avg_loss=0.286, data_time=3593.284s, train_time=27.134s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.28it/s, loss=2.17, avg_loss=0.309, data_time=5.763s, train_time=1.300s]
Train: 100%|██████████| 125/125 [00:58<00:00,  2.15it/s, loss=36.1, avg_loss=0.289, data_time=3618.589s, train_time=27.333s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.30it/s, loss=2.11, avg_loss=0.302, data_time=5.684s, train_time=1.313s]
Train: 100%|██████████| 125/125 [00:58<00:00,  2.15it/s, loss=38.1, avg_loss=0.305, data_time=3613.201s, train_time=27.439s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.32it/s, loss=1.93, avg_loss=0.276, data_time=5.653s, train_time=1.296s]
Train: 100%|██████████| 125/125 [00:58<00:00,  2.15it/s, loss=36.7, avg_loss=0.293, data_time=3597.585s, train_time=27.332s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.27it/s, loss=1.8, avg_loss=0.257, data_time=5.792s, train_time=1.317s]
Train: 100%|██████████| 125/125 [00:5

New best loss: 0.20275842717715672


Train: 100%|██████████| 125/125 [01:00<00:00,  2.06it/s, loss=34.5, avg_loss=0.276, data_time=3769.098s, train_time=28.399s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.17it/s, loss=1.89, avg_loss=0.27, data_time=5.946s, train_time=1.358s]
Train: 100%|██████████| 125/125 [01:00<00:00,  2.06it/s, loss=35.2, avg_loss=0.282, data_time=3777.510s, train_time=28.470s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.30it/s, loss=2.04, avg_loss=0.291, data_time=5.687s, train_time=1.334s]
Train: 100%|██████████| 125/125 [01:00<00:00,  2.08it/s, loss=35.6, avg_loss=0.285, data_time=3737.354s, train_time=28.155s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.30it/s, loss=1.51, avg_loss=0.216, data_time=5.697s, train_time=1.320s]
Train: 100%|██████████| 125/125 [01:00<00:00,  2.06it/s, loss=33.9, avg_loss=0.271, data_time=3765.432s, train_time=28.392s]
Eval: 100%|██████████| 7/7 [00:01<00:00,  4.18it/s, loss=2.21, avg_loss=0.316, data_time=5.824s, train_time=1.348s]
Train: 100%|██████████| 125/125 [00:5

In [None]:
import matplotlib.pyplot as plt
from perlin_numpy import generate_perlin_noise_2d
