In [None]:
%pip install piq pydicom gdown
!gdown 1CRa7U3CLsc8NosfnyCUVxYHbI0V2JVdX
!unzip -q Project.zip

In [3]:
from piq import ssim, SSIMLoss, psnr, haarpsi, vsi, ms_ssim
import numpy as np
from matplotlib import pyplot as plt
import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
from torch import nn
from torch.utils.data import DataLoader, Dataset
import os
from collections import OrderedDict
from tqdm.auto import trange, tqdm
from pydicom import dcmread #reads .dcm file having path to it
# import wandb
# os.environ['WANDB_API_KEY'] = '..'

In [4]:
# Based on https://arxiv.org/abs/2006.09661

class SineLayer(nn.Module):
    """Linear layer whose output is passed through ``sin(omega_0 * .)``.

    Building block of the SIREN architecture (https://arxiv.org/abs/2006.09661).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: forwarded to ``nn.Linear``.
        is_first: selects the weight initialisation used for the first
            layer of a network (see ``init_weights``).
        omega_0: frequency factor applied before the sine.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = omega_0

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from a symmetric uniform distribution.

        First layer: ``U(-1/in_features, 1/in_features)``.
        Other layers: ``U(-sqrt(6/in_features)/omega_0, +sqrt(6/in_features)/omega_0)``.
        Only the weight matrix is touched; the bias keeps whatever
        ``nn.Linear`` gave it.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)
    
    
class Siren(nn.Module):
    """Sequential stack of ``SineLayer`` modules with an optional linear head.

    Layout: one first ``SineLayer`` (``in_features -> hidden_features``),
    ``hidden_layers`` hidden ``SineLayer`` modules, then either a plain
    ``nn.Linear`` (``outermost_linear=True``) or one more ``SineLayer``
    mapping to ``out_features``.

    Args:
        in_features: number of input coordinates per point.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden sine layers after the first one.
        out_features: number of output values per point.
        outermost_linear: finish with a linear layer instead of a sine layer.
        first_omega_0: frequency factor of the first layer.
        hidden_omega_0: frequency factor of all later layers; also used to
            scale the initialisation of the linear head.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        layers.extend(
            SineLayer(hidden_features, hidden_features,
                      is_first=False, omega_0=hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            head = nn.Linear(hidden_features, out_features)
            # Same uniform range the hidden sine layers use for their weights.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
        else:
            head = SineLayer(hidden_features, out_features,
                             is_first=False, omega_0=hidden_omega_0)
        layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Evaluate the network on ``coords`` and return its output.

        The input is first copied, detached and flagged with
        ``requires_grad`` (intended to allow derivatives w.r.t. the input);
        the caller's tensor itself is left untouched.
        """
        grad_coords = coords.clone().detach().requires_grad_(True)
        return self.net(grad_coords)

class SirenCascade(nn.Module):
    """Bank of independent ``Siren`` networks, one per slab of a single axis.

    The coordinate at ``split_index`` (expected in ``[-1, 1]``) is divided
    into ``parts`` equal slabs. Every batch item is routed to the network
    of the slab containing its first point, after the split coordinate has
    been re-expressed relative to that slab (again spanning ``[-1, 1]``).

    Args:
        in_features, hidden_features, hidden_layers, out_features,
        outermost_linear, first_omega_0, hidden_omega_0: forwarded
            unchanged to every ``Siren``.
        parts: number of slabs / sub-networks.
        split_index: index, along the last dimension of ``coords``, of the
            coordinate that is split.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features,
                       outermost_linear=False, parts=10, split_index=-1,
                       first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        self.parts = parts
        self.split_index = split_index
        self.parts_models = nn.ModuleList(
            Siren(in_features, hidden_features, hidden_layers, out_features,
                  outermost_linear, first_omega_0, hidden_omega_0)
            for _ in range(parts)
        )

    def forward(self, coords):
        """Route each batch item to its slab network and stack the outputs.

        Args:
            coords: tensor indexed as ``(batch, point, feature)``. All
                points of one batch item are assumed to lie in the same
                slab; the slab is chosen from the item's first point.

        Returns:
            The per-item network outputs concatenated along dimension 0.
            ``coords`` itself is never modified.
        """
        predictions = []
        for btch_idx in range(coords.size(0)):
            # Work on a copy: writing the local coordinate straight into
            # `coords` would silently alter the caller's tensor.
            item = coords[btch_idx:btch_idx + 1].clone()

            # Map [-1, 1] onto [0, parts]; the integer part is the slab.
            scaled = (item[0, :, self.split_index] / 2 + .5) * self.parts
            part_idx = int(torch.floor(scaled[0]).item())
            # A coordinate of exactly 1 floors to `parts` (out of range),
            # and one slightly below -1 floors to -1 (which would wrap to
            # the last network); keep both inside the valid slabs.
            part_idx = min(max(part_idx, 0), self.parts - 1)

            # Slab-local coordinate, computed per point and rescaled to [-1, 1].
            item[0, :, self.split_index] = (scaled - part_idx) * 2 - 1
            predictions.append(self.parts_models[part_idx](item))
        return torch.cat(predictions, dim=0)


In [12]:
class DcmProjections(Dataset):
    """Dataset yielding ``(coords, image)`` pairs for the DICOM files of a directory.

    For file number ``idx`` (in sorted file-name order) the item is:

    * ``coords``: one row per pixel holding the pixel's grid position
      (each image axis sampled by ``linspace(-1, 1)``) followed by one extra
      column with the constant ``2 * idx / len(dataset) - 1``;
    * ``image``: the pixel array divided by 65535 and mapped to ``[-1, 1]``.

    The coordinate grid is built once from the first file, so every file is
    assumed to have that same shape.

    Args:
        dir_path: directory whose entries are all read with ``dcmread``.

    Raises:
        IndexError: if ``dir_path`` contains no entries.
    """

    def __init__(self, dir_path):
        super().__init__()
        self.dir_path = dir_path
        # os.listdir returns entries in arbitrary order, but the position in
        # this list becomes the last coordinate of every sample, so the order
        # must be deterministic.
        # NOTE(review): plain lexicographic sort; if file names carry
        # unpadded numbers a natural sort may be needed - confirm naming.
        self.file_list = sorted(os.listdir(dir_path))
        if not self.file_list:
            raise IndexError(f'no files found in {dir_path!r}')

        first_image = dcmread(os.path.join(self.dir_path, self.file_list[0])).pixel_array
        axes = [torch.linspace(-1, 1, steps=side) for side in first_image.shape]
        # 'ij' is the historical default; passing it explicitly avoids the
        # deprecation warning of newer torch versions.
        mgrid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
        self.coords = mgrid.reshape(-1, len(first_image.shape))

    def __len__(self):
        """Number of files found in the directory."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(coords, image)`` for file number ``idx``."""
        image = dcmread(os.path.join(self.dir_path, self.file_list[idx])).pixel_array

        # Fixed intensity range used for normalisation (65535 = 2**16 - 1).
        min_val = 0
        max_val = 65535

        # Constant extra column placing this file at 2 * idx / N - 1.
        file_coord = torch.ones((self.coords.shape[0], 1)) / len(self.file_list) * 2 * idx - 1
        coords = torch.cat([self.coords, file_coord], dim=1)
        pixels = torch.Tensor(image.astype(np.float64)) / (max_val - min_val) * 2 - 1
        return coords, pixels

In [42]:
# Dataset over every file of the training directory.
proj_data = DcmProjections('./С004_for_training/train/')

# Interleaved sub-sampling: every 6th file starting at offset 1 for training
# and at offset 4 for validation, so the two subsets never share a file.
index = [*range(1, len(proj_data), 6)]
index_val = [*range(4, len(proj_data), 6)]
proj_data_small = torch.utils.data.Subset(proj_data, index)
proj_data_small_val = torch.utils.data.Subset(proj_data, index_val)

# Loader over the full dataset (disabled):
# dataloader = DataLoader(proj_data,
#                         batch_size=4,
#                         num_workers=2,
#                         shuffle=True,
#                         persistent_workers=True)

# Training loader: shuffled, two items per batch.
dataloader_small = DataLoader(
    proj_data_small,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    persistent_workers=True,
)

# Validation loader: fixed order, one item per batch.
dataloader_small_val = DataLoader(
    proj_data_small_val,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    persistent_workers=True,
)
# print(proj_data[0][0].shape)
# proj_data[0][0][0, :], proj_data[0][0][1200, :]

In [5]:
def PSNR(x, y, data_range=1.0):
    """Peak signal-to-noise ratio of ``x`` against ``y``, delegated to ``piq.psnr``.

    ``data_range`` is passed through as the value range of the inputs.
    """
    score = psnr(x, y, data_range=data_range)
    return score

def SSIM(x, y, data_range=1.0):
    """Structural similarity of ``x`` against ``y``, delegated to ``piq.ssim``.

    ``data_range`` is passed through as the value range of the inputs.
    """
    score = ssim(x, y, data_range=data_range)
    return score

def HaarsPSI(x, y, data_range=1.0):
    """HaarPSI score of ``x`` against ``y``, delegated to ``piq.haarpsi``.

    ``data_range`` is passed through as the value range of the inputs.
    """
    score = haarpsi(x, y, data_range=data_range)
    return score

def VSI(x, y, data_range=1.0):
    """VSI score of ``x`` against ``y``, delegated to ``piq.vsi``.

    ``data_range`` is passed through as the value range of the inputs.
    """
    score = vsi(x, y, data_range=data_range)
    return score

In [None]:
def train(model, dataloader, val_loader, lr=5e-5, total_steps=30):
    """Fit ``model`` with Adam on an MSE loss and validate after every epoch.

    Args:
        model: network mapping a batch of coordinates to values; it must
            already live on the device it should be trained on.
        dataloader: iterable of ``(coords, target)`` training batches.
        val_loader: iterable of ``(coords, target)`` validation batches;
            targets are reshaped to ``(-1, 1, H, W)`` for the image metrics.
        lr: Adam learning rate.
        total_steps: number of epochs (full passes over ``dataloader``).

    Returns:
        A list with one dict per epoch holding the validation averages
        ``val_loss``, ``PSNR``, ``SSIM``, ``HaarsPSI`` and ``VSI`` as plain
        floats (``nan`` when ``val_loader`` produced no batches). The same
        dict is printed after each epoch.
    """
    optim = torch.optim.Adam(lr=lr, params=model.parameters())
    # Follow the model instead of re-deriving the device from CUDA
    # availability, so inputs always end up where the weights are.
    device = next(model.parameters()).device

    history = []
    for _ in trange(total_steps):
        model.train()
        for model_input, ground_truth in tqdm(dataloader, desc='Train', leave=False):
            model_input = model_input.to(device)
            ground_truth = ground_truth.to(device)
            model_output = model(model_input)
            loss = ((model_output.view_as(ground_truth) - ground_truth) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
            # wandb.log({'train_loss': loss.item()})

        model.eval()
        # Plain Python floats: accumulating tensors would keep them on the
        # device and print tensor reprs instead of numbers.
        totals = {'val_loss': 0.0, 'PSNR': 0.0, 'SSIM': 0.0, 'HaarsPSI': 0.0, 'VSI': 0.0}
        n_batches = 0
        with torch.no_grad():
            for model_input, ground_truth in tqdm(val_loader, desc="Validate", leave=False):
                model_input = model_input.to(device)
                ground_truth = ground_truth.to(device)
                ground_truth = ground_truth.view(-1, 1, ground_truth.size(-2), ground_truth.size(-1))

                prediction = torch.clamp(model(model_input), -1, 1).view_as(ground_truth)
                totals['val_loss'] += float(((prediction - ground_truth) ** 2).mean())

                # The image metrics are called with data_range=1, so map
                # both images from [-1, 1] to [0, 1] once.
                prediction_01 = prediction / 2 + .5
                target_01 = ground_truth / 2 + .5
                totals['PSNR'] += float(PSNR(prediction_01, target_01))
                totals['SSIM'] += float(SSIM(prediction_01, target_01))
                totals['HaarsPSI'] += float(HaarsPSI(prediction_01, target_01))
                totals['VSI'] += float(VSI(prediction_01, target_01))
                n_batches += 1

        metrics = {name: (total / n_batches if n_batches else float('nan'))
                   for name, total in totals.items()}
        # wandb.log(metrics)
        print(metrics)
        history.append(metrics)

    return history
        

            # img_grad = gradient(model_output, coords)
            # img_laplacian = laplace(model_output, coords)

            # fig, axes = plt.subplots(1,3, figsize=(18,6))
            # axes[0].imshow(model_output.cpu().view(256,256).detach().numpy())
            # axes[1].imshow(img_grad.norm(dim=-1).cpu().view(256,256).detach().numpy())
            # axes[2].imshow(img_laplacian.cpu().view(256,256).detach().numpy())
            # plt.show()

def save_model(model, args, save_path):
    """Write a checkpoint with the model weights and its arguments.

    The checkpoint is a dict with ``'model_state_dict'``
    (``model.state_dict()``) and ``'args'`` (stored as given), written to
    ``save_path`` with ``torch.save``.
    """
    checkpoint = {}
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['args'] = args
    torch.save(checkpoint, save_path)

def experiment(args=None):
    """Build a model from ``args``, train it and save it to ``./model.ptch``.

    Args:
        args: dict with the keys ``'model_type'`` (``'siren_cascade'`` or
            ``'siren'``), ``'hidden_features'``, ``'hidden_layers'``,
            ``'last_linear'`` and ``'lr'``; ``'parts'`` is required only
            for ``'siren_cascade'``.

    Raises:
        TypeError: if ``args`` is ``None`` (the wandb path that supplied a
            config without arguments is commented out).
        ValueError: if ``'model_type'`` is not one of the supported values.

    The model is saved in the ``finally`` clause, i.e. also when training
    is interrupted or fails.
    """
    if args is None:
        raise TypeError("experiment() requires a config dict in 'args'")
    torch.cuda.empty_cache()
    # if args is not None:
    #     wandb.init(project='CT-compression-siren', config=args)
    # else:
    #     wandb.init()
    # config = wandb.config
    config = args

    model_type = config['model_type']
    if model_type == 'siren_cascade':
        model = SirenCascade(in_features=3,
                             out_features=1,
                             parts=config['parts'],
                             split_index=2,
                             hidden_features=config['hidden_features'],
                             hidden_layers=config['hidden_layers'],
                             outermost_linear=config['last_linear'])
    elif model_type == 'siren':
        model = Siren(in_features=3,
                      out_features=1,
                      hidden_features=config['hidden_features'],
                      hidden_layers=config['hidden_layers'],
                      outermost_linear=config['last_linear'])
    else:
        # Without this branch `model` stayed unbound and the failure only
        # surfaced later as an UnboundLocalError.
        raise ValueError(
            f"unknown model_type {model_type!r}; expected 'siren_cascade' or 'siren'")

    print("Parameters: ", sum(p.numel() for p in model.parameters()))
    model.to(DEVICE)

    # Arguments stored next to the weights so the model can be rebuilt.
    saved_args = {
      'hidden_features': config['hidden_features'],
      'hidden_layers': config['hidden_layers'],
      'lr': config['lr'],
      'last_linear': config['last_linear'],
      'model_type': model_type,
      # 'parts' has no meaning for a plain Siren, so it may be absent.
      'parts': config.get('parts')}
    try:
        train(model, dataloader_small, dataloader_small_val, lr=config['lr'])
    except KeyboardInterrupt:
        print("FINISHING...")
    finally:
        file_path = './model.ptch'
        save_model(model, saved_args, file_path)
        print(f'model saved as {file_path}')
        # wandb.log_artifact(file_path, name=wandb.config['model_type'], type='model')
    # wandb.finish()


In [1]:
# sweep_id = ""
# count = 10
# wandb.agent(sweep_id, function=experiment, count=count, project='CT-compression-siren')

In [None]:
# Hyper-parameters of a single run; 'model_type' may be 'siren_cascade' or 'siren'.
args = dict(
    hidden_features=128,
    hidden_layers=3,
    lr=3e-5,
    last_linear=True,
    model_type='siren_cascade',
    parts=20,
)
experiment(args=args)

{'hidden_features': 128, 'hidden_layers': 3, 'lr': 3e-05, 'last_linear': True, 'model_type': 'siren_cascade', 'parts': 20}


VBox(children=(Label(value='3.897 MB of 3.897 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

{'hidden_features': 128, 'hidden_layers': 3, 'lr': 3e-05, 'last_linear': True, 'model_type': 'siren_cascade', 'parts': 20}
Parameters:  1003540


  0%|          | 0/30 [00:00<?, ?it/s]

Train:   0%|          | 0/859 [00:00<?, ?it/s]

Validate:   0%|          | 0/1716 [00:00<?, ?it/s]



Train:   0%|          | 0/859 [00:00<?, ?it/s]

Validate:   0%|          | 0/1716 [00:00<?, ?it/s]