# Coronagraph inpainting model

## Data downloading and preprocessing

!python3 data.py

In [1]:
from utils import get_default_device
device = get_default_device()
print(device)

cuda


In [2]:
import torch
torch.cuda.empty_cache()

## Creating dataset

In [3]:
batch_size = 1

In [4]:
from torch.utils.data import random_split, DataLoader
from utils import DeviceDataLoader
from data import CoronagraphDataset

dataset = CoronagraphDataset('c3')
# 0.8 - 0.2
train_len = round(0.8*len(dataset))
val_len = len(dataset) - train_len

#random split
train_ds, val_ds = random_split(dataset, [train_len, val_len])

train_dl = DeviceDataLoader(DataLoader(train_ds, batch_size, True, num_workers = 4, pin_memory=True), device)
val_dl = DeviceDataLoader(DataLoader(val_ds, batch_size, True, num_workers = 4, pin_memory=True), device)

  "class": algorithms.Blowfish,


In [5]:
from model import UNetArchitecture, UNetArchitectureDeluxe, SmallUNet
from utils import to_device
from loss import NewInpaintingLoss, UNetLoss, OldInpaintingLoss

# model = to_device(SingleLayer('c3', InpaintingLoss([6,6, 0.05, 10, 10]), (7,5,3)), device)

# model = to_device(UNetArchitecture('c3_UNet', UNetLoss([4, 6, 0.05, 110, 120])), device)

model = to_device(SmallUNet(f'c3_Small_Unet_{0.2}_{0.8}_{0.05}_{1}_{1.2}', NewInpaintingLoss([1., 0.8, 0.05, 100, 80]), 2), device)

1. c3_Small_Unet_0.2_0.8_0.05_1_1.2_10.pt
2. c3_Small_Unet_0.2_0.8_0.05_1_1.2_5.pt
3. c3_Small_Unet_0.2_0.8_0.05_1_1.2_0.pt


In [6]:
import torch
from utils import to_device
sample_tensor = (torch.randn(1,1,1024,1024), torch.ones(1,1,1024,1024))
sample_tensor = to_device(sample_tensor, device)

In [7]:
import torch.nn.functional as F
model.spe_act = lambda x: F.sigmoid(x)

In [8]:
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim import Adam, SGD, RMSprop

model.fit(
    train_dl,
    val_dl,
    epochs = 200,
    lr = 1e-5,
    batch_size = batch_size,
    weight_decay = 0,
    grad_clip = 1e-4,
    opt_func = Adam,
    lr_sched = OneCycleLR,
    saving_div = 5,
    graph = True,
    sample_input = sample_tensor
)

AttributeError: 'NoneType' object has no attribute 'param_groups'

In [9]:
config = torch.load('c3_Small_Unet_0.2_0.8_0.05_1_1.2/models/c3_Small_Unet_0.2_0.8_0.05_1_1.2_5.pt')

In [28]:
from torch.optim import Adam


hola = Adam(model.parameters())

In [29]:
hola

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [30]:
config['optimizer_state_dict']

{'state': {38: {'step': tensor(7896.),
   'exp_avg': tensor([[[[-4.9442e-06,  1.3060e-05,  1.3655e-05,  ...,  2.8693e-06,
               1.4038e-05,  1.4994e-05],
             [-2.3716e-05,  1.5912e-05,  1.4570e-05,  ..., -5.7687e-06,
               5.2068e-06,  1.0633e-05],
             [-2.5163e-05,  1.6941e-05,  1.6413e-05,  ..., -5.1257e-06,
               6.0304e-06,  5.9051e-06],
             ...,
             [-2.6610e-05,  1.3199e-05,  1.3243e-05,  ..., -5.2459e-06,
               6.0760e-06,  3.5997e-06],
             [-2.6607e-05,  1.0738e-05,  9.1095e-06,  ..., -5.4770e-06,
              -1.0052e-06, -6.3546e-06],
             [-1.0314e-05,  1.4907e-05,  2.3609e-05,  ...,  9.0457e-06,
               1.1377e-05,  1.0852e-05]]],
   
   
           [[[ 8.5285e-06,  2.0647e-05,  2.7467e-05,  ...,  1.4922e-05,
               2.9038e-05,  1.2852e-05],
             [-1.2609e-05, -2.3362e-05, -2.3105e-05,  ...,  9.7723e-06,
               4.3152e-06,  1.5961e-05],
             [-1.5

In [32]:
hola = Adam(model.parameters())
hola.load_state_dict(config['optimizer_state_dict'])

In [33]:
hola

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0002
    maximize: False
    weight_decay: 0
)

In [35]:
list(config.keys())

['name',
 'device',
 'epoch',
 'global_step_train',
 'global_step_val',
 'optimizer_state_dict',
 'scheduler_state_dict',
 'model_state_dict']