# Init

In [2]:
def is_notebook() -> bool:
    """Report whether this code is executing inside a Jupyter kernel.

    Returns:
        True only for the ZMQ-based IPython shell (Jupyter notebook / qtconsole).
        False for terminal IPython, any other IPython shell class, and plain Python
        (where ``get_ipython`` is not defined and raises NameError).
    """
    try:
        shell_name = get_ipython().__class__.__name__
        return shell_name == 'ZMQInteractiveShell'
    except NameError:
        # `get_ipython` only exists when running under IPython.
        return False


# Script mode disables the interactive-only cells (e.g. the data-exploration plots).
SCRIPT_MODE = not is_notebook()


## Imports

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib.pyplot as plt
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from PIL import Image, ImageOps
import wandb
from custom_losses import DiceBCELoss

In [4]:
# Run on a CUDA device when one is visible to torch, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('using: ', device)

using:  cpu


# Constants


In [5]:
# Optimizer classes, looked up by the name stored in config['optimizer'].
optimizers = dict(
    SGD=optim.SGD,
    Adam=optim.Adam,
    RMSprop=optim.RMSprop,
)

# Loss classes, looked up by the name stored in config['criteria'] and instantiated
# with the keyword arguments stored under config['criteria.<name>'].
criterias = dict(
    CrossEntropyLoss=nn.CrossEntropyLoss,
    MSELoss=nn.MSELoss,
    BCEWithLogits=nn.BCEWithLogitsLoss,
    DiceBCELoss=DiceBCELoss,
)

# Connect to the Weights & Biases server

In [6]:
# Run hyper-parameters. They are sent to W&B and read back below through `config`.
params = {
    # Optimizer: a key of `optimizers`, plus its constructor arguments.
    'optimizer': "SGD",
    'optim.learning_rate': 0.005,
    'optim.momentum': 0.9,
    # Loss: a key of `criterias`, plus its kwargs under 'criteria.<name>'.
    'criteria': "DiceBCELoss",
    'criteria.DiceBCELoss': {'dice_weight': 1},
    # Training schedule.
    'epochs': 3000,
    'batch_size': 32,
}

# Prompts for an API key when none is stored locally (see the captured output below).
wandb.login()
wandb.init(
    project="Overfitting-MRI-Imaging",
    config=params,
    name="UNet-3D-DiceBCE-Fed",
    entity='hilit',
)
config = wandb.config

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/mac/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhilit[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Load Data


In [None]:
from custom_datasets import Promise12, PROSTATEx, MedicalSegmentationDecathlon, NciIsbi2013


class OverfittingPromise12(Promise12):
    """Promise12 variant whose every item is sample 0.

    Whatever index is requested, the same sample is returned. This lets the model be
    trained (overfitted) on a single scan as a sanity check of the training pipeline.
    The length is inherited unchanged from Promise12.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__(root_dir, transform)
        # Remove the substring "Overfitting" from every processed-file path.
        # NOTE(review): presumably Promise12 derives these paths from the class name, and
        # this makes the subclass reuse the parent's processed files — confirm.
        self.processed_files = dict(
            (key, path.replace("Overfitting", ""))
            for key, path in self.processed_files.items()
        )

    def __getitem__(self, index):
        # `index` is deliberately ignored: every lookup yields the first sample.
        return super().__getitem__(0)

# TODO: Create overfitting datasets (classes) for the rest of them

dataset = OverfittingPromise12(root_dir='data', transform=transforms.ToTensor())

# Train and test loaders wrap the same dataset, whose items are all sample 0. For an
# overfitting check the model is evaluated on exactly what it trains on. The loaders
# differ only in whether they shuffle.
loader_kwargs = dict(dataset=dataset, batch_size=config["batch_size"])
train_loader = torch.utils.data.DataLoader(shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(shuffle=False, **loader_kwargs)

# Explore Data

In [None]:
if not SCRIPT_MODE:
    scan_id = 1
    slice_id = 8
    # Fetch the sample once. Each dataset[...] lookup runs __getitem__ again
    # (presumably reloading and transforming the scan).
    # NOTE: OverfittingPromise12 maps every index to sample 0, so scan_id has no effect
    # with that dataset.
    sample = dataset[scan_id]
    image, target = sample[0], sample[1]
    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].imshow(image[slice_id, :, :], cmap='gray')
    ax[0].set_title(f'scan {scan_id}, slice {slice_id}: input')
    ax[1].imshow(target[slice_id, :, :], cmap='gray')
    ax[1].set_title(f'scan {scan_id}, slice {slice_id}: target')
    plt.show()

In [None]:
if not SCRIPT_MODE:
    scan_id = 1
    slice_id = 14
    # Fetch the sample once. Each dataset[...] lookup runs __getitem__ again
    # (presumably reloading and transforming the scan).
    # NOTE: OverfittingPromise12 maps every index to sample 0, so scan_id has no effect
    # with that dataset.
    sample = dataset[scan_id]
    image, target = sample[0], sample[1]
    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].imshow(image[slice_id, :, :], cmap='gray')
    ax[0].set_title(f'scan {scan_id}, slice {slice_id}: input')
    ax[1].imshow(target[slice_id, :, :], cmap='gray')
    ax[1].set_title(f'scan {scan_id}, slice {slice_id}: target')
    plt.show()

# Model

In [None]:
from models import CNNTarget
# NOTE(review): 15 in/out channels presumably means one channel per MRI slice (the
# exploration cells index slices up to 14) — confirm against the dataset's slice count.
model = CNNTarget(
    in_channels=15,
    out_channels=15,
    features=[4, 8, 16, 32],
)
model = model.to(device)

In [None]:
import inspect

# Build the optimizer named in the config. `momentum` is forwarded only when the
# optimizer's constructor accepts it: SGD and RMSprop do, but optim.Adam does not and
# would raise TypeError. Switching config['optimizer'] to "Adam" therefore works.
optimizer_cls = optimizers[config['optimizer']]
optimizer_kwargs = {'lr': config['optim.learning_rate']}
if 'momentum' in inspect.signature(optimizer_cls).parameters:
    optimizer_kwargs['momentum'] = config['optim.momentum']
optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)

# Build the loss named in the config, with its kwargs stored under 'criteria.<name>'.
# A loss with no such entry (e.g. "MSELoss") is built with its defaults instead of
# failing with KeyError.
criteria_kwargs = config.get('criteria.' + config['criteria'], {})
criteria = criterias[config['criteria']](**criteria_kwargs)
print(model)

# Training