# Develop a light GAN for post processing
* The input must be the biggest size in the dataset
* Both Generator and Discriminator must fit in the GPU (48GB VRAM)


In [7]:
# Find the bigest shape
import os
import tifffile as tiff
import numpy as np
all_shapes = []
root_dataset = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_images"
for file_name in os.listdir(root_dataset):
    if file_name.endswith('.tif'):
        file_path = os.path.join(root_dataset, file_name)
        img_array = tiff.imread(file_path)
        all_shapes.append(img_array.shape)
        if img_array.shape[0]>320:
            print(file_path)

/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_images/3866421231.tif


In [10]:
x_all_shapes = [shape_element[0] for shape_element in all_shapes]
y_all_shapes = [shape_element[1] for shape_element in all_shapes]
z_all_shapes = [shape_element[2] for shape_element in all_shapes]

print(f"Biggest x: {max(x_all_shapes)}")
print(f"Biggest y: {max(y_all_shapes)}")
print(f"Biggest z: {max(z_all_shapes)}")

Biggest x: 384
Biggest y: 384
Biggest z: 384


In [11]:
import json
import sys
sys.path.append("../utils")
from main_train_class import main_train_STU_Net

class postprocessGANs(main_train_STU_Net):
    def __init__(self):
        pass

In [12]:
import torch
import torch.nn as nn

class Generator(torch.nn.Module):
    def __init__(self, in_channels=512, first_channels=32, out_channels=1):
        super(Generator, self).__init__()
        self.in_channels = in_channels
        self.first_channels = first_channels
        self.out_channels = out_channels
        self.activation = torch.nn.Tanh()

        #########################
        ######## Encoder ########
        self.enc1 = nn.Sequential(
            nn.Conv3d(
                in_channels=self.in_channels, out_channels=self.first_channels, kernel_size=(5, 5, 5),
                stride=1, padding=1, bias=False
            ),
            nn.InstanceNorm3d(self.first_channels),
            nn.LeakyReLU(inplace=True)
        )

        self.enc2 = nn.Sequential(
            nn.Conv3d(
                in_channels=self.first_channels, out_channels=self.first_channels*2, kernel_size=(5, 5, 5),
                stride=1, padding=1, bias=False
            ),
            nn.InstanceNorm3d(self.first_channels),
            nn.LeakyReLU(inplace=True)
        )

        #########################
        ######## Decoder ########
        self.dec1 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=self.first_channels*2, out_channels=self.first_channels, kernel_size=(6, 6, 6),
                stride=2, padding=1, bias=False
            ),
            nn.InstanceNorm3d(self.first_channels),
            nn.LeakyReLU(inplace=True)
        )

        self.dec2 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=self.first_channels+self.first_channels, out_channels=self.first_channels, kernel_size=(6, 6, 6),
                stride=2, padding=1, bias=False
            ),
            nn.InstanceNorm3d(self.first_channels),
            nn.LeakyReLU(inplace=True)
        )

        self.out = nn.Sequential(
            nn.Conv3d(
                in_channels=self.first_channels, out_channels=self.out_channels, kernel_size=(5, 5, 5),
                stride=1, padding=1, bias=False
            )
        )


    def forward(self, x):
        # Encoder
        x_enc1 = self.enc1(x) # First encoder (skip)
        x_enc2 = self.enc2(x_enc1) # Second encoder
        # Decoder
        x_dec1 = self.dec1(x_enc2)
        x_dec2 = self.dec2(torch.cat(x_enc1, x_dec1))
        x_out = self.out(x_dec2)
        # Activation
        x_out = self.activation(x_out)
        return x_out

In [16]:
# Test generator
generator = Generator(in_channels=2, first_channels=32, out_channels=1)
generator = generator.cuda()
## Simulate input shape
rand_tensor = torch.randn(1, 1, 320, 320, 320)
rand_coarse_mask = torch.randn(1, 1, 320, 320, 320)
input_tensor = torch.cat([rand_tensor, rand_coarse_mask], dim=1)
input_tensor = input_tensor.cuda()
#Run Generator
generator(input_tensor)

OutOfMemoryError: CUDA out of memory. Tried to allocate 3.83 GiB. GPU 0 has a total capacity of 47.41 GiB of which 1.13 GiB is free. Process 1092925 has 1.92 GiB memory in use. Process 1154661 has 43.80 GiB memory in use. Process 1256514 has 532.00 MiB memory in use. Of the allocated memory 254.40 MiB is allocated by PyTorch, and 17.60 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)