In [2]:
!pip install torchmetrics
!pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import pytorch_lightning as pl
import torchmetrics

import sys
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import os
import numpy as np
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import random

In [4]:
# Mount Google Drive: the dataset reads its .npy volumes from
# /content/drive/MyDrive/.
from google.colab import drive

DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
class SegYDataset(Dataset):
    """Random cubic crops of a seismic volume paired with its fault-label volume.

    Both volumes are loaded once from ``.npy`` files. ``__getitem__`` ignores
    ``idx``: it keeps drawing uniformly positioned crops until the fault crop
    sums to more than ``min_fault_voxels``. It then returns ``(X, Y)`` as
    float32 tensors of shape ``(1, crop_size, crop_size, crop_size)``.

    Args:
        data_type: Unused; kept so existing calls keep working.
        size_scale: Stored as ``self.size_scale``; not used by this class.
        seis_path: Path of the seismic volume (must be a 3-D array).
        fault_path: Path of the fault-label volume. It must have the seismic
            volume's shape so both crops cover the same voxels.
        crop_size: Edge length of each cubic crop.
        min_fault_voxels: A crop is accepted only when ``fault_crop.sum()``
            strictly exceeds this value. This is a voxel count if the labels
            are 0/1 (assumed, not checked).
        length: Value returned by ``__len__`` (nominal samples per epoch).
        max_attempts: Optional cap on crops drawn per item. ``None`` (the
            default) retries indefinitely, as the original code did.

    Raises:
        ValueError: If the seismic volume is not 3-D, the two shapes differ,
            or the crop does not fit inside the volume.
    """

    def __init__(self, data_type=None, size_scale=1,
                 seis_path='/content/drive/MyDrive/reduced_seis_data.npy',
                 fault_path='/content/drive/MyDrive/reduced_fault_data.npy',
                 crop_size=128, min_fault_voxels=70_000, length=100,
                 max_attempts=None):
        self.fault_data = np.load(fault_path)
        self.seis_data = np.load(seis_path)
        self.size_scale = size_scale
        self.crop_size = crop_size
        self.min_fault_voxels = min_fault_voxels
        self.length = length
        self.max_attempts = max_attempts

        # Fail fast here instead of silently misaligning crops, or failing
        # later inside random.randint with an opaque "empty range" error.
        if self.seis_data.ndim != 3:
            raise ValueError(
                f"expected a 3-D seismic volume, got shape {self.seis_data.shape}")
        if self.fault_data.shape != self.seis_data.shape:
            raise ValueError(
                f"fault volume shape {self.fault_data.shape} does not match "
                f"seismic volume shape {self.seis_data.shape}")
        if any(dim < crop_size for dim in self.seis_data.shape):
            raise ValueError(
                f"crop_size={crop_size} does not fit in volume of shape "
                f"{self.seis_data.shape}")

    def __len__(self):
        # Nominal epoch length; items are random crops, not indexed samples.
        return self.length

    def __getitem__(self, idx):
        # idx is ignored: every call draws a fresh random crop.
        size = self.crop_size
        n_il, n_xl, n_z = self.seis_data.shape
        attempts = 0
        while True:
            # Same three randint calls, in the same order as before, so a
            # seeded `random` yields the same crops.
            iline = random.randint(0, n_il - size)
            xline = random.randint(0, n_xl - size)
            zline = random.randint(0, n_z - size)
            window = (slice(iline, iline + size),
                      slice(xline, xline + size),
                      slice(zline, zline + size))

            fault_slice = self.fault_data[window]
            if fault_slice.sum() > self.min_fault_voxels:
                X = torch.Tensor(self.seis_data[window])
                Y = torch.Tensor(fault_slice)
                # Prepend a channel axis: (D, H, W) -> (1, D, H, W).
                return (X[None, :], Y[None, :])

            attempts += 1
            if self.max_attempts is not None and attempts >= self.max_attempts:
                raise RuntimeError(
                    f"no crop with more than {self.min_fault_voxels} fault voxels "
                    f"found in {attempts} attempts")
                

In [6]:
# SegYDataset ignores the index and draws a fresh random crop per item,
# so shuffling would change nothing; one crop per batch.
train_dataset = SegYDataset()
trainloader = DataLoader(dataset=train_dataset, batch_size=1)

In [11]:
class UNet3D(nn.Module):
    """3-D U-Net mapping a 1-channel volume to 1-channel per-voxel logits.

    Encoder: four stages of two unpadded 3x3x3 Conv-BN-ReLU blocks (each conv
    trims one voxel from both sides of every axis), with 2x max-pooling
    between the first three stages.

    Decoder: three stages of 2x transposed-conv upsampling. Each stage
    concatenates the center-cropped encoder map of the matching stage, then
    applies two size-preserving 3x3x3 transposed convs. A 1x1x1 projection to
    one channel and a final 2x upsampling follow.

    The output is left unactivated so it can be fed directly to
    ``BCEWithLogitsLoss``. A 128^3 input yields a 128^3 output.

    NOTE(review): for a 128^3 input, the skip features concatenated at the
    last decoder stage are the central 64^3 of the 124^3 first encoder map,
    and the final layer upsamples that 2x to 128^3. Output voxels therefore
    do not map 1:1 onto input voxels. Confirm this matches how the target
    volumes are prepared.
    """

    def __init__(self):
        super().__init__()

        # Module names and creation order match the original, so saved
        # state_dicts still load and seeded weight init is unchanged.
        self.layer_encoder_1 = self._conv_bn_relu(1, 32)
        self.layer_encoder_2 = self._conv_bn_relu(32, 64)
        self.layer_encoder_3 = self._conv_bn_relu(64, 64)
        self.layer_encoder_4 = self._conv_bn_relu(64, 128)
        self.layer_encoder_5 = self._conv_bn_relu(128, 128)
        self.layer_encoder_6 = self._conv_bn_relu(128, 256)
        self.layer_encoder_7 = self._conv_bn_relu(256, 256)
        self.layer_encoder_8 = self._conv_bn_relu(256, 512)

        self.max_pool_1 = nn.MaxPool3d(2)
        self.max_pool_2 = nn.MaxPool3d(2)
        self.max_pool_3 = nn.MaxPool3d(2)

        self.layer_decoder_1 = self._upconv_relu(512, 512, kernel_size=2, stride=2, padding=0)
        self.layer_decoder_2 = self._upconv_relu(256 + 512, 256, kernel_size=3, stride=1, padding=1)
        self.layer_decoder_3 = self._upconv_relu(256, 256, kernel_size=3, stride=1, padding=1)

        self.layer_decoder_4 = self._upconv_relu(256, 256, kernel_size=2, stride=2, padding=0)
        self.layer_decoder_5 = self._upconv_relu(128 + 256, 128, kernel_size=3, stride=1, padding=1)
        self.layer_decoder_6 = self._upconv_relu(128, 128, kernel_size=3, stride=1, padding=1)

        self.layer_decoder_7 = self._upconv_relu(128, 128, kernel_size=2, stride=2, padding=0)
        self.layer_decoder_8 = self._upconv_relu(64 + 128, 64, kernel_size=3, stride=1, padding=1)
        self.layer_decoder_9 = self._upconv_relu(64, 64, kernel_size=3, stride=1, padding=1)

        # Output head: deliberately no ReLU. BCEWithLogitsLoss expects
        # unbounded logits. A final ReLU pinned every logit >= 0 (probability
        # >= 0.5). A ReLU on the 1-channel projection, combined with the
        # bias-free upsample, made each logit's sign depend only on its
        # position in the 2x2x2 kernel. ReLU has no parameters, so the
        # state_dict keys are unchanged.
        self.layer_decoder_10 = nn.Sequential(
            nn.ConvTranspose3d(64, 1, kernel_size=1, stride=1, padding=0, bias=False))
        self.layer_decoder_11 = nn.Sequential(
            nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2, padding=0, bias=False))

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch):
        # Unpadded 3x3x3 Conv3d + BatchNorm3d + ReLU encoder block.
        return nn.Sequential(nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=0),
                             nn.BatchNorm3d(out_ch),
                             nn.ReLU())

    @staticmethod
    def _upconv_relu(in_ch, out_ch, kernel_size, stride, padding):
        # Bias-free ConvTranspose3d + ReLU decoder block.
        return nn.Sequential(nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size,
                                                stride=stride, padding=padding, bias=False),
                             nn.ReLU())

    def forward(self, x):
        """Return per-voxel logits for a ``(N, 1, D, H, W)`` batch."""
        skip_1 = self.layer_encoder_2(self.layer_encoder_1(x))
        x = self.max_pool_1(skip_1)

        skip_2 = self.layer_encoder_4(self.layer_encoder_3(x))
        x = self.max_pool_2(skip_2)

        skip_3 = self.layer_encoder_6(self.layer_encoder_5(x))
        x = self.max_pool_3(skip_3)
        x = self.layer_encoder_8(self.layer_encoder_7(x))

        # ---------------------------------------------

        x = self._merge_skip(self.layer_decoder_1(x), skip_3)
        x = self.layer_decoder_3(self.layer_decoder_2(x))

        x = self._merge_skip(self.layer_decoder_4(x), skip_2)
        x = self.layer_decoder_6(self.layer_decoder_5(x))

        x = self._merge_skip(self.layer_decoder_7(x), skip_1)
        x = self.layer_decoder_9(self.layer_decoder_8(x))

        x = self.layer_decoder_10(x)
        return self.layer_decoder_11(x)

    def _merge_skip(self, x, skip):
        # Crop the (larger) encoder map to x's spatial size, concat on channels.
        return torch.cat((x, self.center_crop(skip, x.size()[2:5])), 1)

    def center_crop(self, layer, target_sizes):
        """Crop the last three axes of a 5-D tensor to ``target_sizes``, centered.

        When the size difference is odd, the extra voxel is dropped from the
        high end (the start offset uses floor division).
        """
        batch_size, n_channels, dim1, dim2, dim3 = layer.size()
        dim1_c = (dim1 - target_sizes[0]) // 2
        dim2_c = (dim2 - target_sizes[1]) // 2
        dim3_c = (dim3 - target_sizes[2]) // 2
        return layer[:, :, dim1_c:dim1_c+target_sizes[0], dim2_c:dim2_c+target_sizes[1], dim3_c:dim3_c+target_sizes[2]]

In [8]:
class LitModel(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with ``BCEWithLogitsLoss``.

    Args:
        model: Network whose output is raw per-voxel logits with the same
            shape as the targets (``BCEWithLogitsLoss`` requires matching
            shapes).
        lr: Adam learning rate; defaults to the previously hard-coded 1e-5.
    """

    def __init__(self, model, lr=1e-5):
        super().__init__()
        self.model = model
        self.lr = lr
        self.loss = nn.BCEWithLogitsLoss()
        self.iou = self.iou_pytorch
        # The metric is a child module, so Lightning moves it to the training
        # device. The previous hard-coded cuda:0 crashed on CPU-only machines.
        self.dice = torchmetrics.Dice(average='macro', num_classes=1)

    def iou_pytorch(self, outputs: torch.Tensor, labels: torch.Tensor):
        """Per-sample IoU, bucketed into 0.1 steps (0 for IoU <= 0.5).

        Args:
            outputs: Binary predicted masks as an integer/bool tensor (bitwise
                ``&``/``|`` are used), shaped ``(N, 1, ...)`` or ``(N, ...)``.
            labels: Binary target masks of integer/bool dtype, with or without
                the singleton channel axis.

        Returns:
            Float tensor of shape ``(N,)``.
        """
        SMOOTH = 1e-6
        outputs = outputs.squeeze(1)
        if labels.dim() > outputs.dim():
            labels = labels.squeeze(1)

        # Reduce over every non-batch axis. The old fixed (1, 2) left the last
        # spatial axis unreduced for 3-D volumes; for 2-D masks it is the same.
        spatial_dims = tuple(range(1, outputs.dim()))
        intersection = (outputs & labels).float().sum(spatial_dims)
        union = (outputs | labels).float().sum(spatial_dims)

        iou = (intersection + SMOOTH) / (union + SMOOTH)

        # Map IoU in (0.5, 1] onto 0.1 .. 1.0; anything <= 0.5 scores 0.
        return torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(), lr=self.lr)

    def get_metrics(self, y_pred, y, metrics_type="train"):
        """Log the Dice score of logits ``y_pred`` against targets ``y``.

        The model emits logits, so they are converted to probabilities first.
        Dice's probability threshold (0.5 by default in torchmetrics) then
        corresponds to logit 0, consistent with ``BCEWithLogitsLoss``.
        """
        probs = torch.sigmoid(y_pred).reshape(-1)
        # Follow the prediction's device instead of hard-coding 'cuda'.
        target = y.reshape(-1).int().to(probs.device)
        dice = self.dice(probs, target)
        self.log(f'dice/{metrics_type}', dice)

    def _shared_step(self, batch, stage, log_metrics=False):
        # Forward pass + BCE loss, logged under loss/<stage>.
        x, y = batch
        y_pred = self.forward(x)
        loss = self.loss(y_pred, y)
        self.log(f'loss/{stage}', loss, prog_bar=True)
        if log_metrics:
            self.get_metrics(y_pred, y, stage)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, val_batch, batch_idx):
        return self._shared_step(val_batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test", log_metrics=True)

In [None]:
# Seed Python's `random` (used for the crops), NumPy and torch (weight init)
# so runs are reproducible.
SEED = 42
pl.seed_everything(SEED)

model = UNet3D()

unet_model = LitModel(model)
# Fall back to CPU instead of failing at Trainer construction without a GPU.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
trainer_unet_model = pl.Trainer(max_epochs=10, accelerator=accelerator)
trainer_unet_model.fit(unet_model, trainloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type              | Params
--------------------------------------------
0 | model | UNet3D            | 19.1 M
1 | loss  | BCEWithLogitsLoss | 0     
2 | dice  | Dice              | 0     
--------------------------------------------
19.1 M    Trainable params
0         Non-trainable params
19.1 M    Total params
76.280    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Pickles the whole LightningModule (not just a state_dict), so the class
# definitions above must be importable when this file is loaded back.
checkpoint_path = '/content/drive/MyDrive/3dunet_2.pth'
torch.save(unet_model, checkpoint_path)