# Medical image analysis with PyTorch

Build a deep convolutional network from scratch in PyTorch for an image-to-image translation task, and train it on a subset of the IXI dataset to synthesize T2-weighted (T2-w) images from T1-weighted (T1-w) images.

## Step 4: Train the network

In [0]:
# Training hyperparameters:
#   valid_split - fraction of the dataset held out for validation
#   batch_size  - samples per mini-batch
#   n_jobs      - DataLoader worker processes
#   n_epochs    - full passes over the training subset
valid_split, batch_size, n_jobs, n_epochs = 0.1, 16, 12, 50

In [0]:
# Random (128, 128, 32) crops converted to tensors. NOTE: both loaders below
# share this dataset object, so validation samples are randomly cropped too and
# the reported validation loss is itself a noisy estimate.
tfms = Compose([RandomCrop3D((128,128,32)), ToTensor()])

# set up training and validation data loader for nifti images
# preload=False keeps CPU memory use low; presumably preload=True caches every
# volume in memory up front (faster epochs, more RAM) -- confirm in NiftiDataset.
dataset = NiftiDataset(t1_dir, t2_dir, tfms, preload=False)
num_train = len(dataset)
indices = list(range(num_train))
split = int(valid_split * num_train)
# Use a fixed-seed permutation instead of the unseeded global RNG so the
# validation subset is identical on every run. Otherwise, resuming from a saved
# checkpoint would validate on images the model was already trained on.
split_seed = 0
shuffled = np.random.default_rng(split_seed).permutation(indices)
valid_idx = shuffled[:split]            # ndarray, as before
train_idx = shuffled[split:].tolist()   # list, as before
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size,
                          num_workers=n_jobs, pin_memory=True)
valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=batch_size,
                          num_workers=n_jobs, pin_memory=True)

### Milestone 4

In [0]:
# The training loop below requires a GPU. Fail with an explicit error rather
# than an `assert`, which is silently skipped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available; this notebook requires a GPU.')
device = torch.device('cuda:0')
# Let cuDNN benchmark convolution algorithms for the observed input shape.
# This pays off when input sizes stay constant, as they should with a fixed
# RandomCrop3D size.
torch.backends.cudnn.benchmark = True

In [0]:
# Instantiate the `Unet` model (defined elsewhere in the notebook) with its
# default constructor arguments; it is moved to the GPU in a later cell.
model = Unet()

In [0]:
# Optional: uncomment to initialize the model from a previously saved
# checkpoint, e.g. the 'trained.pth' written by the last cell of this notebook.
#model.load_state_dict(torch.load('trained.pth'));

In [0]:
# Move parameters and buffers onto the GPU *before* building the optimizer, so
# the optimizer tracks the device-resident parameters.
model.to(device)
# AdamW with the library-default learning rate and a small decoupled weight decay.
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-6)
# Smooth L1 loss between predicted and target volumes; nn.MSELoss() is a
# drop-in alternative.
criterion = nn.SmoothL1Loss()

In [0]:
# Per-epoch lists of per-batch losses (one inner list per epoch).
train_losses, valid_losses = [], []
n_batches = len(train_loader)  # number of training batches per epoch
for t in range(1, n_epochs + 1):
    # training
    t_losses = []
    model.train()
    for src, tgt in train_loader:
        # non_blocking lets the host->GPU copy overlap with compute; it is
        # effective because the loaders use pin_memory=True.
        src = src.to(device, non_blocking=True)
        tgt = tgt.to(device, non_blocking=True)
        optimizer.zero_grad()
        out = model(src)
        loss = criterion(out, tgt)
        loss_val = loss.item()
        t_losses.append(loss_val)
        # Check every batch, before backward/step, so non-finite gradients are
        # never applied and we don't burn the rest of the epoch (plus a full
        # validation pass) on already-corrupted weights.
        if not np.isfinite(loss_val):
            train_losses.append(t_losses)  # keep the partial epoch for inspection
            raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')
        loss.backward()
        optimizer.step()
    train_losses.append(t_losses)

    # validation
    v_losses = []
    model.eval()
    with torch.no_grad():
        for src, tgt in valid_loader:
            src = src.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            out = model(src)
            loss = criterion(out, tgt)
            v_losses.append(loss.item())
    valid_losses.append(v_losses)

    # An empty validation loader (e.g. a tiny dataset with int(valid_split * N) == 0)
    # would make np.mean([]) emit a RuntimeWarning; report NaN explicitly instead.
    v_mean = np.mean(v_losses) if v_losses else float('nan')
    log = f'Epoch: {t} - Training Loss: {np.mean(t_losses):.2e}, Validation Loss: {v_mean:.2e}'
    print(log)

Epoch: 1 - Training Loss: 1.44e-01, Validation Loss: 8.80e-01
Epoch: 2 - Training Loss: 1.20e-01, Validation Loss: 2.06e-01
Epoch: 3 - Training Loss: 1.17e-01, Validation Loss: 1.63e-01
Epoch: 4 - Training Loss: 1.09e-01, Validation Loss: 1.21e-01
Epoch: 5 - Training Loss: 1.05e-01, Validation Loss: 1.01e-01
Epoch: 6 - Training Loss: 1.01e-01, Validation Loss: 1.06e-01
Epoch: 7 - Training Loss: 1.04e-01, Validation Loss: 3.73e-01
Epoch: 8 - Training Loss: 1.02e-01, Validation Loss: 1.15e-01
Epoch: 9 - Training Loss: 1.01e-01, Validation Loss: 1.97e-01
Epoch: 10 - Training Loss: 9.05e-02, Validation Loss: 1.41e-01
Epoch: 11 - Training Loss: 8.75e-02, Validation Loss: 9.73e-02
Epoch: 12 - Training Loss: 9.25e-02, Validation Loss: 1.21e-01
Epoch: 13 - Training Loss: 8.07e-02, Validation Loss: 1.14e-01
Epoch: 14 - Training Loss: 9.38e-02, Validation Loss: 1.12e-01
Epoch: 15 - Training Loss: 9.20e-02, Validation Loss: 8.68e-02
Epoch: 16 - Training Loss: 8.71e-02, Validation Loss: 9.11e-02
... (output truncated)

In [0]:
# Persist only the learned parameters (the state dict), not the pickled module;
# restore later with model.load_state_dict(torch.load('trained.pth')).
torch.save(obj=model.state_dict(), f='trained.pth')