In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from skimage import io, transform as trans
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image


# Dataset root directory. Set the RPMR_DATA_ROOT environment variable to point
# at the data on another machine instead of editing this machine-specific default.
data_root = os.environ.get('RPMR_DATA_ROOT', 'C:\\Users\\ai_to\\Downloads\\RPMR\\RPMR')

In [2]:
def unet_conv(in_planes, out_planes):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages that preserve spatial size.

    The first conv maps ``in_planes`` to ``out_planes`` channels; the second keeps
    ``out_planes``. The ReLUs are not in-place.
    """
    conv = nn.Sequential(
        nn.Conv2d(in_planes, out_planes, 3, 1, 1),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(False),
        nn.Conv2d(out_planes, out_planes, 3, 1, 1),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(False),
    )
    return conv


class Uresnet(nn.Module):
    """U-Net-style encoder/decoder with residual blocks (U-ResNet).

    Encoder: a 3x3 conv stem, then three levels of 2x2 max-pool -> 1x1 channel
    expansion -> ``unet_conv`` with an identity residual.
    Decoder: bilinear upsampling to the matching skip tensor's spatial size,
    1x1 transposed-conv channel reduction + BatchNorm, concatenation with the
    skip tensor, a 1x1 fusion conv, and a residual ``unet_conv``.
    A final 1x1 conv produces ``label_nbr`` output channels.

    Args:
        input_nbr: number of input channels.
        label_nbr: number of output channels.

    Input is (batch, input_nbr, H, W); output is (batch, label_nbr, H, W).
    H and W need not be divisible by 8 because every upsampling step targets
    the exact size of its skip connection.
    """

    def __init__(self, input_nbr=3, label_nbr=2):
        super(Uresnet, self).__init__()

        # ---- encoder ----
        self.downconv1 = nn.Sequential(
            nn.Conv2d(input_nbr, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(False),
        )      # long skip No.1

        # Parameter-free, so a single instance is shared by all three levels.
        self.maxpool = nn.MaxPool2d(2, 2)

        self.downconv2 = nn.Sequential(
            nn.Conv2d(64, 128, 1),
            nn.ReLU(False),
        )      # residual block No.1 (1x1 channel expansion)

        self.downconv3 = unet_conv(128, 128)  # long skip No.2

        self.downconv4 = nn.Sequential(
            nn.Conv2d(128, 256, 1),
            nn.ReLU(False),
        )      # residual block No.2

        self.downconv5 = unet_conv(256, 256)  # long skip No.3

        self.downconv6 = nn.Sequential(
            nn.Conv2d(256, 512, 1),
            nn.ReLU(False),
        )      # residual block No.3

        self.downconv7 = unet_conv(512, 512)  # bottleneck

        # ---- decoder ----
        self.updeconv2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=1),
            nn.BatchNorm2d(256),
        )
        self.upconv3 = nn.Sequential(
            nn.Conv2d(512, 256, 1),
            nn.ReLU(False),
        )       # fuses upsampled features with long skip No.3
        self.upconv4 = unet_conv(256, 256)

        self.updeconv3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=1),
            nn.BatchNorm2d(128),
        )
        self.upconv5 = nn.Sequential(
            nn.Conv2d(256, 128, 1),
            nn.ReLU(False),
        )       # fuses upsampled features with long skip No.2
        self.upconv6 = unet_conv(128, 128)

        self.updeconv4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=1),
            nn.BatchNorm2d(64)
        )
        self.upconv7 = nn.Sequential(
            nn.Conv2d(128, 64, 1),
            nn.ReLU(False),
        )       # fuses upsampled features with long skip No.1
        self.upconv8 = unet_conv(64, 64)

        self.last = nn.Conv2d(64, label_nbr, 1)

    @staticmethod
    def _upsample_to(x, ref):
        # Bilinear resize of x to ref's spatial size. For sides divisible by 8
        # this equals scale_factor=2. For other sizes, the floor-ing max-pools
        # would otherwise leave a size mismatch in torch.cat.
        return nn.functional.interpolate(x, size=ref.shape[-2:], mode='bilinear')

    def forward(self, x):
        # The residual additions are deliberately out-of-place. The tensors they
        # update are ReLU outputs, which autograd saves for the backward pass.
        # An in-place `+=` would make loss.backward() raise a RuntimeError.

        # encoding
        x1 = self.downconv1(x)
        x3 = self.downconv2(self.maxpool(x1))
        x4 = self.downconv3(x3) + x3

        x6 = self.downconv4(self.maxpool(x4))
        x7 = self.downconv5(x6) + x6

        x9 = self.downconv6(self.maxpool(x7))
        x10 = self.downconv7(x9) + x9

        # decoding
        y4 = self.updeconv2(self._upsample_to(x10, x7))
        y5 = self.upconv3(torch.cat([y4, x7], 1))
        y6 = self.upconv4(y5) + y5

        y7 = self.updeconv3(self._upsample_to(y6, x4))
        y8 = self.upconv5(torch.cat([y7, x4], 1))
        y9 = self.upconv6(y8) + y8

        y10 = self.updeconv4(self._upsample_to(y9, x1))
        y11 = self.upconv7(torch.cat([y10, x1], 1))
        y12 = self.upconv8(y11) + y11

        return self.last(y12)

def uresnet():
    """Factory: build a ``Uresnet`` with its default configuration."""
    return Uresnet()


In [3]:
def double_conv(in_planes, out_planes):
    """Two (3x3 conv -> BatchNorm -> in-place ReLU) stages that preserve spatial size."""
    conv = nn.Sequential(
        nn.Conv2d(in_planes, out_planes, 3, 1, 1),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(True),
        nn.Conv2d(out_planes, out_planes, 3, 1, 1),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(True),
    )
    return conv


class Unet(nn.Module):
    """Classic U-Net with three down/up-sampling levels for per-pixel classification.

    Encoder: four ``double_conv`` stages separated by 2x2 max-pooling.
    Decoder: 2x2 transposed-conv upsampling, concatenation with the matching
    encoder feature map, then a ``double_conv``. A final 1x1 conv emits logits.

    Args:
        in_channels: channels of the input image (default 3).
        n_classes: number of segmentation classes, i.e. output channels (default 6).

    Input is (batch, in_channels, H, W). H and W must be divisible by 8, so that
    every upsampled map matches its skip tensor in ``torch.cat``.
    """

    def __init__(self, in_channels=3, n_classes=6):
        super(Unet, self).__init__()

        self.downconv1 = double_conv(in_channels, 64)
        self.maxpool = nn.MaxPool2d(2, 2)

        self.downconv2 = double_conv(64, 128)
        self.downconv3 = double_conv(128, 256)
        self.downconv4 = double_conv(256, 512)

        self.updeconv2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.upconv2 = double_conv(512, 256)

        self.updeconv3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.upconv3 = double_conv(256, 128)

        self.updeconv4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.upconv4 = double_conv(128, 64)

        self.out = nn.Conv2d(64, n_classes, 1)

        # Weight initialization: Kaiming-normal for Conv2d layers. ConvTranspose2d
        # is not a Conv2d subclass, so it keeps PyTorch's default init.
        # Normalization layers start at unit scale and zero shift.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # encoder
        x1 = self.downconv1(x)
        x3 = self.downconv2(self.maxpool(x1))
        x5 = self.downconv3(self.maxpool(x3))
        x7 = self.downconv4(self.maxpool(x5))

        # decoder: upsample, concatenate with the same-resolution encoder map, fuse
        x = self.upconv2(torch.cat([self.updeconv2(x7), x5], 1))
        x = self.upconv3(torch.cat([self.updeconv3(x), x3], 1))
        x = self.upconv4(torch.cat([self.updeconv4(x), x1], 1))

        return self.out(x)
    
def unet():
    """Factory: build a ``Unet`` with its default configuration."""
    return Unet()

In [4]:
class LSTMModel(nn.Module):
    """(Stacked) LSTM followed by a linear read-out applied at every time step.

    Args:
        input_size: features per time step.
        hidden_size: LSTM hidden-state size.
        num_layers: number of stacked LSTM layers.
        output_size: features produced per time step by the linear head.

    Input: (batch, seq_len, input_size) because ``batch_first=True``, or
    (seq_len, input_size) for a single unbatched sequence.
    Returns: (batch, seq_len, output_size), or (seq_len, output_size) when the
    input is unbatched. There is one prediction per time step.
    """

    def __init__(self, input_size=64, hidden_size=128, num_layers=1, output_size=1):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Omitting (h0, c0) makes nn.LSTM start from zero states. Those states
        # automatically match x's device, dtype and batched/unbatched layout.
        # Hand-built float32 zeros broke for float64/float16 and 2-D inputs.
        out, _ = self.lstm(x)

        # The linear head decodes the hidden state of EVERY time step, not only the last.
        return self.fc(out)

In [5]:
class UResNetLSTM(nn.Module):
    """U-ResNet followed by an LSTM head (``Uresnet`` -> ``LSTMModel``).

    NOTE(review): this forward pass cannot run as written:
      * ``Uresnet`` contains ``BatchNorm2d`` layers, so it takes and returns 4-D
        tensors (batch, uresnet_label_nbr, H, W). The ``nn.LSTM`` inside
        ``LSTMModel`` (batch_first=True) accepts only 2-D or 3-D input whose last
        dimension equals ``lstm_input_size``. A reshape from the U-ResNet map to
        (batch, seq_len, lstm_input_size) is missing. The intended mapping is
        not visible here; TODO confirm with the author.
      * ``LSTMModel`` returns one prediction per time step (3-D), so
        ``unsqueeze(1)`` gives a 4-D tensor and ``repeat(1, 10, 1)`` (only three
        repeat sizes) raises. It would only fit a 2-D (batch, lstm_output_size)
        tensor, such as the last time step alone.
    """
    def __init__(self, uresnet_input_nbr=3, uresnet_label_nbr=2, lstm_input_size=64, lstm_hidden_size=128, lstm_num_layers=1, lstm_output_size=3):
        super(UResNetLSTM, self).__init__()
        
        # U-ResNet model
        self.uresnet = Uresnet(input_nbr=uresnet_input_nbr, label_nbr=uresnet_label_nbr)
        
        # LSTM model
        self.lstm = LSTMModel(input_size=lstm_input_size, hidden_size=lstm_hidden_size, num_layers=lstm_num_layers, output_size=lstm_output_size)

    def forward(self, x):
        # U-ResNet forward pass; x presumably (batch, uresnet_input_nbr, H, W) — TODO confirm
        uresnet_output = self.uresnet(x)
        
        # LSTM forward pass, using the U-ResNet output as input
        # NOTE(review): uresnet_output is 4-D here, which nn.LSTM rejects (see class docstring).
        lstm_output = self.lstm(uresnet_output)
        
        # Commented-out alternative from the original. Note that LSTMModel stores no
        # `output_size` attribute, so `self.lstm.output_size` would also fail.
        #batch_size, _, _ = lstm_output.size()
        #lstm_output = lstm_output.view(batch_size, -1, self.lstm.output_size)
        # Insert a new axis 1 and repeat along it 10 times (hard-coded count).
        lstm_output = lstm_output.unsqueeze(1).repeat(1, 10, 1) 

        return lstm_output

In [6]:
# Data preprocessing
def normalization(image):
    """Min-max scale an image to the [0, 1] range as float32.

    Args:
        image: array-like of pixel values (any numeric dtype, any shape).

    Returns:
        A float32 array of the same shape, where the minimum maps to 0 and the
        maximum maps to 1. A constant image (max == min) has no range to scale
        by, so it maps to all zeros instead of the NaNs a 0/0 division would give.
    """
    image = np.asarray(image, dtype=np.float32)
    min_val = image.min()
    value_range = image.max() - min_val
    if value_range == 0:
        return np.zeros_like(image)
    return (image - min_val) / value_range

# Data augmentation
def data_augmentation(brightness=0, contrast=0, flip_prob=1):
    """Build a PIL-image -> tensor augmentation pipeline.

    Args:
        brightness: ``ColorJitter`` brightness jitter (0 leaves brightness unchanged).
        contrast: ``ColorJitter`` contrast jitter (0 leaves contrast unchanged).
        flip_prob: probability of a vertical flip. The default of 1 flips EVERY
            image (kept for backward compatibility); pass e.g. 0.5 for a genuinely
            random flip.

    Returns:
        A ``transforms.Compose`` that flips, colour-jitters, then converts with ``ToTensor``.
    """
    return transforms.Compose([
        transforms.RandomVerticalFlip(p=flip_prob),
        transforms.ColorJitter(brightness=brightness, contrast=contrast),
        transforms.ToTensor(),
    ])


class PorousMediaDataset(Dataset):
    """Dataset whose items are image sequences, one per sub-folder of ``data_root``.

    Every image in a sub-folder is handled the same way:
      1. resized to 200x200 and min-max normalised;
      2. re-encoded as an 8-bit PIL image;
      3. passed through ``transform`` if one is given, otherwise converted to a
         float tensor in [0, 1].
    The frames are then stacked, in sorted file-name order, into one tensor.
    """

    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        # Sorted so item indices are stable: os.listdir order is arbitrary.
        self.porous_media_sequences = sorted(os.listdir(data_root))

    def __len__(self):
        return len(self.porous_media_sequences)

    def __getitem__(self, idx):
        sequence_folder = os.path.join(self.data_root, self.porous_media_sequences[idx])
        frames = []
        for image_file in sorted(os.listdir(sequence_folder)):
            pil_image = self._load_frame(os.path.join(sequence_folder, image_file))
            if self.transform:
                frames.append(self.transform(pil_image))
            else:
                # torch.stack cannot consume PIL images, so convert explicitly.
                frames.append(self._pil_to_tensor(pil_image))
        return torch.stack(frames)

    @staticmethod
    def _load_frame(image_path):
        # Read, resize to 200x200, min-max normalise and re-encode as a uint8 PIL image.
        image = io.imread(image_path)
        resized_image = trans.resize(image, (200, 200), mode='constant', anti_aliasing=True)
        normalized_image = normalization(resized_image)
        return Image.fromarray((normalized_image * 255).astype(np.uint8))

    @staticmethod
    def _pil_to_tensor(pil_image):
        # Mirror ToTensor for 8-bit images: scale to [0, 1] float and put channels
        # first, (H, W) -> (1, H, W) and (H, W, C) -> (C, H, W).
        array = np.asarray(pil_image, dtype=np.float32) / 255.0
        array = array[None] if array.ndim == 2 else array.transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(array))

In [9]:
class UResNetDataset(Dataset):
    """Sliding-window dataset of consecutive frames from ``<root_dir>/samples``.

    Item ``idx`` is the ``sequence_length`` files starting at position ``idx`` of
    the sorted file list. Each frame is processed as follows:
      1. resized to 200x200 and min-max normalised;
      2. re-encoded as an 8-bit PIL image;
      3. passed through ``transform`` if given, otherwise converted to a float
         tensor in [0, 1].
    The frames are stacked with the sequence on the first axis.
    """

    def __init__(self, root_dir, sequence_length=10, transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'samples')
        self.image_list = sorted(os.listdir(self.image_dir))
        self.sequence_length = sequence_length
        self.transform = transform

    def __len__(self):
        # Only windows that fit entirely inside the file list are valid. Counting
        # every file let the last sequence_length - 1 indices run past the end.
        return max(0, len(self.image_list) - self.sequence_length + 1)

    def __getitem__(self, idx):
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            raise IndexError(f"window index out of range for {n_windows} windows")

        image_sequence = []
        for img_file in self.image_list[idx:idx + self.sequence_length]:
            pil_image = self._load_frame(os.path.join(self.image_dir, img_file))
            # Transform the processed PIL image, not the raw file contents.
            if self.transform:
                image_sequence.append(self.transform(pil_image))
            else:
                image_sequence.append(self._pil_to_tensor(pil_image))

        # Stack the per-frame tensors into a single tensor.
        return torch.stack(image_sequence)

    @staticmethod
    def _load_frame(image_path):
        # Read, resize to 200x200, min-max normalise and re-encode as a uint8 PIL image.
        image = io.imread(image_path)
        resized_image = trans.resize(image, (200, 200), mode='constant', anti_aliasing=True)
        normalized_image = normalization(resized_image)
        return Image.fromarray((normalized_image * 255).astype(np.uint8))

    @staticmethod
    def _pil_to_tensor(pil_image):
        # Mirror ToTensor for 8-bit images: scale to [0, 1] float and put channels
        # first, (H, W) -> (1, H, W) and (H, W, C) -> (C, H, W).
        array = np.asarray(pil_image, dtype=np.float32) / 255.0
        array = array[None] if array.ndim == 2 else array.transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(array))

In [10]:
# Augmentation pipeline with colour jitter disabled (brightness=0, contrast=0).
# NOTE(review): transformFlipBC is built but not passed to the dataset below.
transformFlipBC = data_augmentation(brightness=0, contrast=0)

# Sliding-window frame dataset rooted at data_root.
dataset = UResNetDataset(root_dir=data_root)


In [None]:
# ---- Split the dataset 80 / 10 / 10 ----
train_ratio = 0.8
valid_ratio = 0.1
test_ratio = 0.1  # the test split receives whatever train + validation leave over

dataset_length = len(dataset)
train_length = int(train_ratio * dataset_length)
valid_length = int(valid_ratio * dataset_length)
test_length = dataset_length - train_length - valid_length

# A seeded generator makes the split reproducible across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_length, valid_length, test_length], generator=split_generator
)

# ---- Data loaders ----
# Worker processes started with the "spawn" method (the default on Windows)
# can fail to unpickle Dataset classes defined inside a notebook. On Windows,
# load data in the main process instead.
batch_size = 16
num_workers = 0 if os.name == 'nt' else 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

# ---- Model ----
# The model is trained as an autoencoder (the loss compares outputs with inputs),
# so its output channel count must equal its input channel count. Read that
# count from one sample. Items stack per-frame tensors, so dim -3 is assumed to
# be the channel axis (frames (C, H, W) — TODO confirm for the chosen transform).
n_channels = dataset[0].shape[-3]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: this rebinds the name of the `uresnet()` factory defined earlier.
uresnet = Uresnet(input_nbr=n_channels, label_nbr=n_channels).to(device)
seg_criterion = nn.MSELoss()
optimizer = optim.Adam(uresnet.parameters(), lr=0.001)


def _as_frames(batch):
    # Uresnet is a per-image 2-D network, so fold the sequence axis into the
    # batch axis: (B, T, C, H, W) -> (B*T, C, H, W). 4-D batches pass through.
    return batch.flatten(0, 1) if batch.dim() == 5 else batch


num_epochs = 50
train_losses = []
valid_losses = []

for epoch in range(num_epochs):
    uresnet.train()
    total_loss = 0.0

    for inputs in train_loader:
        inputs = _as_frames(inputs).to(device)

        optimizer.zero_grad()
        outputs = uresnet(inputs)
        # Unsupervised reconstruction loss (mean squared error against the input)
        loss = seg_criterion(outputs, inputs)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError when a split is empty (tiny dataset).
    epoch_train_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_train_loss}")

    # Validation loop
    uresnet.eval()
    total_valid_loss = 0.0
    with torch.no_grad():
        for inputs in valid_loader:
            inputs = _as_frames(inputs).to(device)
            outputs = uresnet(inputs)
            total_valid_loss += seg_criterion(outputs, inputs).item()

    epoch_valid_loss = total_valid_loss / max(len(valid_loader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {epoch_valid_loss}")
    train_losses.append(epoch_train_loss)
    valid_losses.append(epoch_valid_loss)

# Save the trained weights.
torch.save(uresnet.state_dict(), "u_resnet_model.pth")

epochs_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots()
ax.plot(epochs_axis, train_losses, label='Train Loss')
ax.plot(epochs_axis, valid_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()