In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import torch
import torch.nn as nn
import torch.optim as optim

import os
import json

# Default size (in inches) for every matplotlib figure created in this notebook.
matplotlib.rcParams['figure.figsize'] = (11.75, 8.5)

# Seed the NumPy and PyTorch global generators so that weight initialisation and
# shuffled data loading repeat across "Restart & Run All" executions.
# NOTE: some CUDA kernels are still non-deterministic, so GPU runs may differ slightly.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Root of the dataset and one sub-directory per split.
DATA_PATH = os.path.join("data", "256window")
train_path = os.path.join(DATA_PATH, "train")
eval_path = os.path.join(DATA_PATH, "eval")
val_path = os.path.join(DATA_PATH, "val")


def load_metadata(split_path):
    """Return the parsed contents of ``metadata.json`` found in ``split_path``.

    The file is opened explicitly as UTF-8 so that parsing does not depend on
    the platform's default locale encoding.
    """
    with open(os.path.join(split_path, "metadata.json"), "r", encoding="utf-8") as file:
        return json.load(file)


train_list = load_metadata(train_path)
eval_list = load_metadata(eval_path)
val_list = load_metadata(val_path)
# Displayed as the cell output: number of records per split.
len(train_list), len(eval_list), len(val_list)

In [None]:
class FaultDataset(Dataset):
    """Map-style dataset that pairs a seismic image with its fault-label image.

    Each element of ``data`` is a record with a ``"data"`` entry (file name
    inside ``<data_path>/seis``) and a ``"label"`` entry (file name inside
    ``<data_path>/fault``).

    Args:
        data: indexable collection of records as described above.
        transform: callable applied to the opened seismic image, or ``None``
            to return the PIL image untouched.
        data_path: directory containing the ``seis`` and ``fault`` folders.
    """

    def __init__(self, data, transform, data_path):
        self.data, self.transform, self.data_path = data, transform, data_path

    def __len__(self):
        """Number of records in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(seis, label)`` for the record at ``index``.

        The label image is always converted with ``transforms.ToTensor()``;
        the seismic image goes through ``self.transform`` only when one is set.
        """
        entry = self.data[index]

        seis_file = os.path.join(self.data_path, "seis", entry["data"])
        seis = Image.open(seis_file)
        if self.transform is not None:
            seis = self.transform(seis)

        fault_file = os.path.join(self.data_path, "fault", entry["label"])
        to_tensor = transforms.ToTensor()
        label = to_tensor(Image.open(fault_file))
        return seis, label

# The seismic images only need to be converted from PIL images to tensors.
preprocess = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = FaultDataset(train_list, preprocess, train_path)
eval_dataset = FaultDataset(eval_list, preprocess, eval_path)
val_dataset = FaultDataset(val_list, preprocess, val_path)

# Sanity check: render one seismic window and its fault mask, then print the
# shape of a seismic tensor. The sample is loaded once and indexed with the
# normal subscript syntax instead of calling ``__getitem__`` directly.
sample_seis, sample_label = train_dataset[3]
to_pil = transforms.ToPILImage()
display(to_pil(sample_seis))
display(to_pil(sample_label))
print(train_dataset[0][0].shape)

BATCH_SIZE = 10
# Only the training loader shuffles; the final loader yields one sample per batch.
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, BATCH_SIZE)
val_loader = DataLoader(val_dataset, 1)
print(len(train_loader), len(eval_loader), len(val_loader))

In [None]:
class Unet(nn.Module):
    """U-Net style encoder/decoder producing a one-channel sigmoid map.

    The encoder has four stages (64, 128, 256, 512 channels), each followed by
    a 2x2 max-pool, and a 1024-channel bottleneck. The decoder mirrors it with
    2x2 transposed convolutions and concatenates the matching encoder output
    (skip connection) before every stage. ``encoder4`` and ``bottleneck`` carry
    a ``Dropout2d(p=0.5)`` between the last batch-norm and ReLU.

    The input must have one channel; because of the four 2x down-samplings the
    spatial size has to be divisible by 16 for the skip concatenations to match.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, dropout=False):
        """Return the layers of one stage: two 3x3 conv + batch-norm + ReLU groups.

        With ``dropout=True`` a ``Dropout2d(p=0.5)`` is inserted just before the
        final ReLU. A list is returned so callers can append extra layers.
        """
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
        ]
        if dropout:
            layers.append(nn.Dropout2d(p=0.5, inplace=True))
        layers.append(nn.ReLU(inplace=True))
        return layers

    @staticmethod
    def _downsample():
        """2x2 max-pool that halves the spatial resolution."""
        return nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _upsample(in_channels, out_channels):
        """2x2 transposed convolution that doubles the spatial resolution."""
        return nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the same order and with the same
        # attribute names / layer positions as before, so existing checkpoints
        # (state_dict keys such as "encoder1.0.weight") remain loadable.

        # Contracting path.
        self.encoder1 = nn.Sequential(*self._conv_stage(1, 64))
        self.pool1 = self._downsample()
        self.encoder2 = nn.Sequential(*self._conv_stage(64, 128))
        self.pool2 = self._downsample()
        self.encoder3 = nn.Sequential(*self._conv_stage(128, 256))
        self.pool3 = self._downsample()
        self.encoder4 = nn.Sequential(*self._conv_stage(256, 512, dropout=True))
        self.pool4 = self._downsample()

        self.bottleneck = nn.Sequential(*self._conv_stage(512, 1024, dropout=True))

        # Expanding path: every decoder receives the up-sampled features
        # concatenated with the skip connection, hence twice the channels.
        self.unpool4 = self._upsample(1024, 512)
        self.decoder4 = nn.Sequential(*self._conv_stage(1024, 512))
        self.unpool3 = self._upsample(512, 256)
        self.decoder3 = nn.Sequential(*self._conv_stage(512, 256))
        self.unpool2 = self._upsample(256, 128)
        self.decoder2 = nn.Sequential(*self._conv_stage(256, 128))
        self.unpool1 = self._upsample(128, 64)

        # The last decoder also holds the output head: 64 -> 2 -> 1 channels + sigmoid.
        final_layers = self._conv_stage(128, 64)
        final_layers += [
            nn.Conv2d(in_channels=64, out_channels=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, padding=0),
            nn.Sigmoid(),
        ]
        self.decoder1 = nn.Sequential(*final_layers)

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid output of ``decoder1``."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        features = self.bottleneck(self.pool4(skip4))

        # Walk back up, deepest level first, merging each skip connection.
        decode_path = (
            (self.unpool4, self.decoder4, skip4),
            (self.unpool3, self.decoder3, skip3),
            (self.unpool2, self.decoder2, skip2),
            (self.unpool1, self.decoder1, skip1),
        )
        for unpool, decoder, skip in decode_path:
            features = decoder(torch.cat((unpool(features), skip), dim=1))
        return features

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * sum(x * y) + s) / (sum(x) + sum(y) + s)``.

    Args:
        weight: accepted for interface compatibility; not used.
        size_average: accepted for interface compatibility; not used.
        thresh: stored on the instance as ``self.thresh``; it does not take
            part in the loss computation.
    """

    def __init__(self, weight=None, size_average=True, thresh=0.5):
        super(DiceLoss, self).__init__()
        self.thresh = thresh

    def forward(self, inputs, targets, smooth=1e-6):
        """Return the Dice loss between ``inputs`` and ``targets``.

        Both tensors are flattened, so the coefficient is computed over the
        whole batch at once. ``smooth`` keeps the ratio defined when both
        tensors sum to zero.
        """
        # reshape() works on non-contiguous tensors as well (e.g. after a
        # transpose or slicing), where view(-1) would raise a RuntimeError.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()

        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

def iou_coeff(inputs, targets, thresh=0.5, smooth=1e-6):
    """Intersection-over-union between a thresholded prediction and the target.

    Args:
        inputs: NumPy array or torch tensor of prediction scores.
        targets: array/tensor of ground-truth values with the same number of
            elements as ``inputs``.
        thresh: scores ``>= thresh`` count as positive.
        smooth: added to numerator and denominator so that two empty masks
            give 1 instead of a division by zero.

    Returns:
        ``(intersection + smooth) / (union + smooth)`` as a scalar.
    """
    # Build a fresh boolean mask instead of overwriting ``inputs`` in place:
    # the caller's data is left untouched (tensor ``flatten`` may return a
    # view) and the result no longer depends on the order of two in-place
    # assignments, which zeroed every prediction whenever ``thresh`` was > 1.
    predicted = (inputs >= thresh).flatten()
    truth = targets.flatten()

    intersection = (predicted * truth).sum()
    union = predicted.sum() + truth.sum() - intersection

    return (intersection + smooth) / (union + smooth)


def fda_coeff(inputs, targets, thresh=0.5, smooth=1e-6):
    """Share of the ground truth covered by the thresholded prediction.

    Computes ``intersection / ground truth``: the overlap between the
    binarised prediction and ``targets`` divided by the sum of ``targets``.

    Args:
        inputs: NumPy array or torch tensor of prediction scores.
        targets: array/tensor of ground-truth values with the same number of
            elements as ``inputs``.
        thresh: scores ``>= thresh`` count as positive.
        smooth: added to numerator and denominator so that an empty ground
            truth gives 1 instead of a division by zero.

    Returns:
        ``(intersection + smooth) / (targets.sum() + smooth)`` as a scalar.
    """
    # A fresh boolean mask avoids mutating the caller's data (tensor
    # ``flatten`` may return a view) and the order-dependent in-place
    # thresholding that zeroed every prediction whenever ``thresh`` was > 1.
    predicted = (inputs >= thresh).flatten()
    truth = targets.flatten()

    intersection = (predicted * truth).sum()

    return (intersection + smooth) / (truth.sum() + smooth)


def train_epoch(network, loader, optimizer, criterion, report_frequency=6):
    """Train ``network`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        network: model to optimise; it is switched to training mode.
        loader: iterable yielding ``(inputs, labels)`` batches.
        optimizer: optimiser that updates the parameters of ``network``.
        criterion: callable ``criterion(output, labels)`` returning a scalar loss.
        report_frequency: kept for interface compatibility; currently unused.

    Returns:
        The average of the per-batch losses as a Python ``float``
        (``nan`` when ``loader`` yields no batches).
    """
    network.train()

    # Send the batches to wherever the model lives instead of relying on a
    # notebook-level global being defined and matching the model.
    first_param = next(network.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0
    batches = 0
    for inputs, labels in loader:
        inputs = inputs.to(run_device)
        labels = labels.to(run_device)

        optimizer.zero_grad()
        loss = criterion(network(inputs), labels)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain number, so the running total does not keep
        # every batch's loss tensor (and its autograd history) alive.
        running_loss += loss.item()
        batches += 1

    if batches == 0:
        # Nothing was processed; a mean is undefined.
        return float("nan")
    return running_loss / batches

def test_epoch(network, loader):
    network.eval()
    iou_statistic = 0
    fda_statistic = 0
    for _, data in enumerate(loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        output = network(inputs)

        output = output.squeeze().data.cpu().numpy()
        labels = labels.squeeze().cpu().numpy()

        iou_statistic += iou_coeff(output, labels)
        fda_statistic += fda_coeff(output, labels)
    return iou_statistic / len(loader), fda_statistic / len(loader)

def fit_network(network, train_loader, val_loader, optimizer, criterion, epochs_num, checkpoint_dir=None):
    """Train ``network`` for ``epochs_num`` epochs, reporting and checkpointing each one.

    After every epoch the mean training loss and the IoU / FDA measured on
    ``val_loader`` are printed, and the model's ``state_dict`` is written to
    ``Epoch_<n>.pth.gz``.

    Args:
        network: model to train.
        train_loader: batches passed to ``train_epoch``.
        val_loader: batches passed to ``test_epoch``.
        optimizer: optimiser passed to ``train_epoch``.
        criterion: loss function passed to ``train_epoch``.
        epochs_num: number of epochs to run.
        checkpoint_dir: directory for the checkpoint files; it is created if
            missing. ``None`` (default) keeps the previous behaviour of
            writing into the current working directory.
    """
    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(1, epochs_num + 1):
        epoch_loss = train_epoch(network, train_loader, optimizer, criterion)
        iou_statistic, fda_statistic = test_epoch(network, val_loader)

        info_line = f"Epoch {epoch:3d} || Loss: {epoch_loss:.4f} || IoU: {iou_statistic:.4f} FDA: {fda_statistic:.4f} "
        print(info_line)

        # The file is written by torch.save as-is; no gzip compression is
        # applied here despite the ".gz" suffix.
        checkpoint_name = f"Epoch_{epoch}.pth.gz"
        if checkpoint_dir is not None:
            checkpoint_name = os.path.join(checkpoint_dir, checkpoint_name)
        torch.save(network.state_dict(), checkpoint_name)

In [None]:
# Train a fresh U-Net with the Dice loss and Adam (lr=1e-3) for 100 epochs.
# The per-epoch IoU / FDA metrics are measured on `eval_loader`.
network = Unet().to(device)
criterion = DiceLoss()
optimizer = optim.Adam(network.parameters(), lr=1e-3)
fit_network(
    network=network,
    train_loader=train_loader,
    val_loader=eval_loader,
    optimizer=optimizer,
    criterion=criterion,
    epochs_num=100,
)

In [None]:
# Restore the weights stored for epoch 85 into a new U-Net placed on `device`
# and switch it to evaluation mode for inference.
weights_path = os.path.join('data', '256window', 'weights', 'Epoch_85.pth.gz')
model = Unet().to(device)
# map_location remaps the stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
saved_weights = torch.load(weights_path, map_location=device)
model.load_state_dict(saved_weights)
model.eval()

def test_and_plot(network, loader, amount=20):
    with torch.no_grad():
        j = 0
        for i, data in enumerate(loader, 0):
            if j == amount:
                break
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            output = network(inputs)

            inputs = inputs.squeeze().cpu().numpy()
            output = output.squeeze().data.cpu().numpy()
            labels = labels.squeeze().cpu().numpy()

            output[output >= 0.5] = 1
            output[output < 0.5] = 0

            fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 20))
            ax[0].imshow(labels, )
            ax[1].imshow(inputs, )
            ax[2].imshow(output)
            plt.show()
            j+=1


# Visualise predictions of the restored model on the held-out `val_loader`.
test_and_plot(network=model, loader=val_loader)