In [1]:
%load_ext autoreload
%autoreload 2
from types import SimpleNamespace
import json
import os
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from models.efficient_unet import AbstractUNet
from dataset import CatDataset
from inpaint_tools import read_file_list
from skimage import io
import numpy as np
import pathlib

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Candidate architectures for the efficient U-Net; only `args` (selected below) is used.
# Block codes: m=MBConv, f=FusedMBConv, u=Unet. res_mode: "cat" or "add".
# "init_mode" spelling is kept as-is; presumably matched verbatim by AbstractUNet.
UNET_ARGS_5_BLOCKS = {"unet": {"block": "ffmmm",
                               "act": "silu",
                               "res_mode": "cat",
                               "init_mode": "effecientnetv2",
                               "downscale_mode": "avgpool",
                               "upscale_mode": "bilinear",
                               "input_channels": 4,
                               "output_channels": 3,
                               "num_blocks": 5,
                               "num_c": [8, 16, 32, 48, 64],
                               "num_repeat": [1, 2, 2, 4, 4],
                               "expand_ratio": [1, 4, 4, 6, 6],
                               "SE": [0, 0, 1, 1, 1]
                               }}

UNET_ARGS_4_BLOCKS = {"unet": {"block": "ffmmm",
                               "act": "silu",
                               "res_mode": "cat",
                               "init_mode": "effecientnetv2",
                               "downscale_mode": "avgpool",
                               "upscale_mode": "bilinear",
                               "input_channels": 4,
                               "output_channels": 3,
                               "num_blocks": 4,
                               "num_c": [16, 32, 64, 96],
                               "num_repeat": [1, 2, 2, 4],
                               "expand_ratio": [1, 4, 4, 6],
                               "SE": [0, 1, 1, 1]
                               }}

# Active config. A saved checkpoint can only be loaded into a model built from the
# same config it was trained with.
args = UNET_ARGS_4_BLOCKS

# Sanity-check that every per-block list has exactly num_blocks entries.
_unet_cfg = args["unet"]
for _key in ("num_c", "num_repeat", "expand_ratio", "SE"):
    assert len(_unet_cfg[_key]) == _unet_cfg["num_blocks"], f"{_key} must have num_blocks entries"
# The 4-block config reuses the 5-character block string; presumably only the first
# num_blocks characters are used — TODO confirm in AbstractUNet.
assert len(_unet_cfg["block"]) >= _unet_cfg["num_blocks"], "block string shorter than num_blocks"

# Round-trip through JSON to turn the nested dicts into attribute-access namespaces
# (args_n.unet.num_c, ...); lists stay lists.
args_n = json.loads(json.dumps(args), object_hook=lambda item: SimpleNamespace(**item))

In [3]:
# Use the first GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
net = AbstractUNet(args_n).to(device)

In [4]:
# Smoke test: a zero input with the configured number of input channels should come
# back with the configured number of output channels at the same spatial size.
# no_grad avoids building an autograd graph for this throwaway forward pass.
dummy_input = torch.zeros((1, args_n.unet.input_channels, 360, 360), device=device)
with torch.no_grad():
    output = net(dummy_input)
assert output.shape == (1, args_n.unet.output_channels, 360, 360), output.shape
output.shape

torch.Size([1, 3, 360, 360])

In [5]:
# Notebook-wide configuration consumed by CatDataset and the train / test cells.
# The input directory defaults to the author's local path and can be overridden with
# the MISSING_DATA_INPUT_DIR environment variable on other machines.
settings = {
    "team_data": {
        "name": "YourTeamNameHere",
    },
    # Key spelling kept as-is; it may be read elsewhere — verify before renaming.
    "training_parms": {
        "method": "YourMethodHere",
        "dummy_value": 1000,
    },
    "dirs": {
        "input_data_dir": os.environ.get(
            "MISSING_DATA_INPUT_DIR",
            "C:/Users/lowes/OneDrive/Skrivebord/DTU/summer_school_23/MissingDataChallenge/data/",
        ),
        "output_data_dir": "missing_data_output/",
    },
    "challenge_server": {
        "address": "http://fungi.compute.dtu.dk:8080",
    },
    "data_set": "training",
    "batch_size": 4,
    "num_workers": 0,
}


# Build the training dataset and show the shape of each tensor in one sample.
dataset_train = CatDataset(settings)
first_sample = dataset_train[0]
print([item.shape for item in first_sample])

dl = DataLoader(
    dataset_train,
    batch_size=settings["batch_size"],
    shuffle=True,
    num_workers=settings["num_workers"],
)

[torch.Size([4, 360, 360]), torch.Size([1, 360, 360]), torch.Size([3, 360, 360])]


## Train

In [6]:
from inpaint_tools import save_inp_out, save_image

# --- Data ------------------------------------------------------------------
# Each split gets its own shallow copy of `settings` rather than mutating the shared
# dict in place, so a dataset never sees a "data_set" value meant for another split.
dataset_train = CatDataset({**settings, "data_set": "training"})
dl_train = DataLoader(dataset_train,
                      batch_size=settings["batch_size"],
                      shuffle=True,
                      num_workers=settings["num_workers"])

dataset_val = CatDataset({**settings, "data_set": "validation_200"})
# Fixed order, so the example image saved each epoch comes from the same samples
# and epochs can be compared visually.
dl_val = DataLoader(dataset_val,
                    batch_size=settings["batch_size"],
                    shuffle=False,
                    num_workers=settings["num_workers"])

# --- Model, loss, optimizer --------------------------------------------------
save_every = 4   # checkpoint on every save_every-th epoch (4, 8, 12 with NUM_EPOCHS = 12)
NUM_EPOCHS = 12

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AbstractUNet(args_n).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

save_dir = os.path.join(settings["dirs"]["output_data_dir"], "efficient_unet")
save_image_dir = os.path.join(save_dir, "train_images")
save_model_dir = os.path.join(save_dir, "models")
# torch.save does not create missing parent directories, so create both up front.
for _out_dir in (save_image_dir, save_model_dir):
    pathlib.Path(_out_dir).mkdir(parents=True, exist_ok=True)


def _forward_batch(batch):
    """Run `model` on one (model_input, mask, image) batch from a CatDataset loader.

    Moves the three tensors to `device`, runs the model and returns
    ``(loss, outputs, model_input, mask, image)``. The loss is MSE between
    ``outputs * mask`` and ``image * mask``, so only pixels where mask is non-zero
    contribute differences; nn.MSELoss still averages over *all* elements.
    """
    model_input, mask, image = (t.to(device) for t in batch)
    outputs = model(model_input)
    loss = criterion(outputs * mask, image * mask)
    return loss, outputs, model_input, mask, image


# --- Training loop -----------------------------------------------------------
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    with tqdm.tqdm(dl_train, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
        for batch in tepoch:
            optimizer.zero_grad()
            loss, *_ = _forward_batch(batch)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # one device->host sync per batch
            running_loss += batch_loss
            tepoch.set_postfix(loss=batch_loss)
    average_loss = running_loss / max(len(dl_train), 1)

    # Validation: no gradients; the last batch is kept for the example images.
    model.eval()
    running_val = 0.0
    last_val_batch = None
    with torch.no_grad():
        for batch in dl_val:
            loss, *last_val_batch = _forward_batch(batch)
            running_val += loss.item()
    avg_val = running_val / max(len(dl_val), 1)
    print(f"epoch {epoch+1}: train loss {average_loss:.6f}, validation loss {avg_val:.6f}")

    if last_val_batch is not None:
        outputs, model_input, mask, image = last_val_batch
        save_image_name = os.path.join(save_image_dir, f"val_epoch_{epoch+1}.png")
        # NOTE(review): both helpers receive the same file name; confirm save_inp_out
        # does not overwrite the file written by save_image.
        save_image(save_image_name, outputs[0], mask[0], image[0])
        save_inp_out(save_image_name, outputs[0], mask[0], model_input[0])

    # Checkpoint only when the (1-based) epoch number is a multiple of save_every.
    if (epoch + 1) % save_every == 0:
        save_net = os.path.join(save_model_dir, f"epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), save_net)



Epoch 1:  20%|█▉        | 248/1242 [01:12<04:51,  3.41batch/s, loss=0.0381]

In [None]:
# Optional manual checkpoint (disabled): uncomment to save the current `model` weights
# as <save_dir>/models/epoch_<epoch+1>.pt, using the `epoch` value left over from the
# training loop. torch.save does not create missing parent directories, hence the mkdir.
# save_net = os.path.join(save_dir,"models",f"epoch_{epoch+1}.pt")
# pathlib.Path(os.path.join(save_dir,"models")).mkdir(parents=True, exist_ok=True)
# torch.save(model.state_dict(), save_net)

## Testing

In [None]:
# Load one saved checkpoint (<save_dir>/models/epoch_<N>.pt) for inference.
# The model must be built from the same `args_n` config the checkpoint was trained
# with; otherwise load_state_dict raises on missing/unexpected keys or size mismatches.
# NOTE(review): this cell's saved output shows first_conv with 8 output channels, which
# matches the num_blocks=5 config rather than the active one — confirm which config
# produced the checkpoint below.
epoch_test = 14  # epoch number of the checkpoint to evaluate
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AbstractUNet(args_n).to(device)
save_net = os.path.join(save_dir, "models", f"epoch_{epoch_test}.pt")
if not os.path.isfile(save_net):
    raise FileNotFoundError(f"checkpoint {save_net} not found; run training or change epoch_test")
# map_location remaps tensors saved from a GPU so the checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load(save_net, map_location=device))
model.eval()
print("model loaded")

AbstractUNet(
  (first_conv): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (last_conv): Conv2d(8, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (DownBlocks): ModuleList(
    (0): Sequential(
      (0): FusedMBConv(
        (conv): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU()
          (3): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): Sequential(
      (0): FusedMBConv(
        (conv): Sequential(
          (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU()
          (3): Conv

In [None]:
from inpaint_tools import save_test_image

# Inference on the held-out test split; each prediction is written to
# <save_dir>/test_200/<image id>.png. A copy of `settings` is used so the shared
# config is not mutated in place.
test_settings = {**settings, "data_set": "test_200"}

save_test_dir = os.path.join(save_dir, test_settings["data_set"])
pathlib.Path(save_test_dir).mkdir(parents=True, exist_ok=True)

dataset_test = CatDataset(test_settings, test=True)
# Order is irrelevant (every output is saved under its own id), so no shuffling.
dl_test = DataLoader(dataset_test,
                     batch_size=test_settings["batch_size"],
                     shuffle=False,
                     num_workers=test_settings["num_workers"])

# Ensure BatchNorm uses running statistics even if `model` was left in train mode.
model.eval()
with torch.no_grad():
    for model_input, mask, im_id in dl_test:
        model_input, mask = model_input.to(device), mask.to(device)

        outputs = model(model_input)

        # NOTE(review): assumes ids collate to a list of strings; integer ids would
        # collate to a tensor and format as "tensor(...)" — verify against CatDataset.
        for i, image_id in enumerate(im_id):
            save_image_name = os.path.join(save_test_dir, f"{image_id}.png")
            save_test_image(save_image_name, outputs[i], mask[i], model_input[i])
    

