In [55]:
%load_ext autoreload
%autoreload 2
from types import SimpleNamespace
import json
import os
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from models.efficient_unet import AbstractUNet
from dataset import CatDataset
from inpaint_tools import read_file_list
from skimage import io
import numpy as np
import pathlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
args = {"unet": {"block": "ffmmm", #m=MBConv,f=FusedMBConv,u=Unet 
                    "act": "silu",
                    "res_mode": "cat", #cat, add
                    "init_mode": "effecientnetv2",
                    "downscale_mode": "avgpool",
                    "upscale_mode": "bilinear",
                    "input_channels": 4,
                    "output_channels": 3,
                    "num_blocks": 5,
                    "num_c": [8,16,32,48,64],
                    "num_repeat": [1,2,2,4,4],
                    "expand_ratio": [1,4,4,6,6],
                    "SE": [0,0,1,1,1]
                }}

args = {"unet": {"block": "ffmmm", #m=MBConv,f=FusedMBConv,u=Unet 
                    "act": "silu",
                    "res_mode": "cat", #cat, add
                    "init_mode": "effecientnetv2",
                    "downscale_mode": "avgpool",
                    "upscale_mode": "bilinear",
                    "input_channels": 4,
                    "output_channels": 3,
                    "num_blocks": 4,
                    "num_c": [8,16,32,64],
                    "num_repeat": [1,2,2,4],
                    "expand_ratio": [1,4,4,6],
                    "SE": [0,0,1,1]
                }}
args_n = json.loads(json.dumps(args), object_hook=lambda item: SimpleNamespace(**item))

In [57]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
net = AbstractUNet(args_n).to(device)

In [58]:
test = torch.zeros((1,4,360,360)).to(device)
output = net(test)
output.shape

torch.Size([1, 3, 360, 360])

In [59]:
settings = {
	"team_data":
	{
		"name": "YourTeamNameHere"
	},
    "training_parms": {
		"method": "YourMethodHere",
		"dummy_value": 1000
    },
    "dirs": {
            "input_data_dir": "C:/Users/lowes/OneDrive/Skrivebord/DTU/summer_school_23/MissingDataChallenge/data/",
			"output_data_dir": "missing_data_output/"
	},
	"challenge_server": {
		"address": "http://fungi.compute.dtu.dk:8080"
	},
	"data_set": "training",
	"batch_size": 4,
	"num_workers": 0,
}


dataset_train = CatDataset(settings)
print([d.shape for d in dataset_train[0]])

dl = DataLoader(dataset_train, 
				batch_size=settings["batch_size"],
                shuffle=True,
                num_workers=settings["num_workers"])

[torch.Size([4, 360, 360]), torch.Size([1, 360, 360]), torch.Size([3, 360, 360])]


In [60]:
import copy
def save_image(save_name, output, mask, image):
    
    output = output.permute(1,2,0).detach().cpu().numpy()
    mask = mask.squeeze().cpu().numpy()
    image = image.permute(1,2,0).cpu().numpy()
    mask = np.stack((mask,mask,mask),-1)
    combined = copy.deepcopy(image)
    combined[mask.astype(np.bool)] = output[mask.astype(np.bool)]
    print(output.shape)
    print(mask.shape)
    print(image.shape)
    print(combined.shape)
    arr = np.concatenate((image,mask,output,combined),1)
    arr = np.clip(arr,0,1)
    io.imsave(save_name, arr)

## Train

In [61]:
dataset_train = CatDataset(settings)
dl_train = DataLoader(dataset_train, 
				batch_size=settings["batch_size"],
                shuffle=True,
                num_workers=settings["num_workers"])

settings["data_set"] = "validation_200"
dataset_val = CatDataset(settings)
dl_val = DataLoader(dataset_val, 
				batch_size=settings["batch_size"],
                shuffle=True,
                num_workers=settings["num_workers"])

# Create the model, loss function, and optimizer
NUM_EPOCHS = 100

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AbstractUNet(args_n).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
val_loss = "NA"
save_dir = os.path.join(settings["dirs"]["output_data_dir"], "efficient_unet")
save_image_dir = os.path.join(save_dir,"train_images")
pathlib.Path(save_image_dir).mkdir(parents=True, exist_ok=True)

# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    # Create a progress bar using tqdm
    with tqdm.tqdm(dl_train, unit="batch") as tepoch:
        for batch in tepoch:
            model_input, mask, image = batch
            model_input, mask, image = model_input.to(device), mask.to(device), image.to(device)
            
            optimizer.zero_grad()

            # Forward pass
            outputs = model(model_input)

            loss = criterion(outputs*mask, image*mask)

            # Backpropagation and optimization
            loss.backward()
            optimizer.step()

            # Update progress bar description with loss
            tepoch.set_description(f"Epoch {epoch+1}")
            tepoch.set_postfix(loss=loss.item())

            # Accumulate loss for this epoch
            running_loss += loss.item()

    # Calculate and print average loss for the epoch
    average_loss = running_loss / len(dl_train)

    model.eval()
    running_val = 0.0

    for batch in dl_val:
        model_input, mask, image = batch
        model_input, mask, image = model_input.to(device), mask.to(device), image.to(device)

        outputs = model(model_input)

        loss = criterion(outputs*mask, image*mask)
        running_val += loss.item()

    avg_val = running_val / len(dl_val)
    print("validation loss: ", avg_val)
    
    save_image_name = os.path.join(save_image_dir,f"val_epoch_{epoch}.png")
    save_image(save_image_name,outputs[0],mask[0],image[0])



Epoch 1: 100%|██████████| 1242/1242 [05:07<00:00,  4.03batch/s, loss=0.00687]
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  combined[mask.astype(np.bool)] = output[mask.astype(np.bool)]


validation loss:  0.0024539827648550274
(360, 360, 3)
(360, 360, 3)
(360, 360, 3)
(360, 360, 3)


Epoch 2:  21%|██        | 256/1242 [00:53<03:49,  4.29batch/s, loss=0.00363]