In [None]:
# Load Libraries
import os
import glob
from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor, Compose
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR 
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
# Data preprocessing
# Per-frame pipeline: every loaded image goes through these steps, in order,
# before the frames are stacked into sequences.
augment_transforms = transforms.Compose(
    [
        # Produce a 3-channel grayscale frame
        transforms.Grayscale(num_output_channels=3),
        # Resize to 224x224 to fit the VGG19 input dimensions
        transforms.Resize((224, 224)),
        # PIL image -> float tensor
        transforms.ToTensor(),
    ]
)

# Root of the training split; its subfolders are scanned for *.tif frames
root_dir = './UCSD_Anomaly_Dataset.v1p2/UCSDped1/Train/'

def skipping_stride(images_dir, window_size=10, stride=1):
    """Cut the ``*.tif`` frames of one folder into fixed-length strided sequences.

    Frames are processed in sorted filename order. Consecutive windows start
    ``stride * window_size`` frames apart and keep every ``stride``-th frame,
    so windows do not overlap. A trailing window that cannot supply
    ``window_size`` frames is dropped; this guarantees that every returned
    tensor has the same leading dimension, so the caller can stack them.

    Args:
        images_dir: Directory containing the ``*.tif`` frames.
        window_size: Number of frames in each returned sequence.
        stride: Step between consecutive frames inside one sequence.

    Returns:
        list of tensors, one per complete window, each obtained by stacking
        ``window_size`` transformed frames along a new dimension 0. The list
        is empty when no complete window fits.
    """
    images = sorted(glob.glob(os.path.join(images_dir, '*.tif')))
    span = stride * window_size  # number of source frames one window covers
    sequences = []
    # NOTE (original author): if we are predicting LSTM images then the step
    # should not be multiplied by window_size.
    for start in range(0, len(images), span):
        window = images[start:start + span:stride]
        if len(window) < window_size:
            # Incomplete tail: a shorter sequence could not be stacked with
            # the full-length ones, so skip it.
            continue
        frames = [augment_transforms(default_loader(f)) for f in window]
        sequences.append(torch.stack(frames, dim=0))
    return sequences

# Collect every sequence from every clip folder, at frame strides 1 and 2.
data = []

for subfolder in tqdm(sorted(os.listdir(root_dir)), desc="Processing Subfolders"):
    folder_path = os.path.join(root_dir, subfolder)
    if not os.path.isdir(folder_path):
        # Stray files next to the clip folders contain no frames
        continue
    for stride_value in [1,2]:
        sequences = skipping_stride(folder_path, stride=stride_value)
        if sequences:
            data.extend(sequences)

if not data:
    # Fail with a readable message instead of torch.stack's error on an empty list
    raise RuntimeError(f"No frame sequences were found under {root_dir!r}")

# One tensor holding all sequences, stacked along a new leading dimension
data_tensor = torch.stack(data, dim=0)
print(data_tensor.shape)

# Create the dataset from the tensor
dataset = TensorDataset(data_tensor)

# Split the dataset into training and validation sets (80% / 20%).
# A seeded generator keeps the split identical across kernel restarts, so
# validation losses from different runs refer to the same samples.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
def plot_image_tensor(tensor_image):
    """Display a single channels-first image tensor with matplotlib.

    Args:
        tensor_image: 3-D tensor laid out as (C, H, W). It may live on any
            device and may be attached to an autograd graph; it is detached
            and copied to the CPU before conversion to NumPy.
    """
    # .numpy() only works on CPU tensors that do not require grad
    image = tensor_image.detach().cpu()
    # CxHxW to HxWxC, otherwise imshow reports invalid dimensions
    image = image.permute(1, 2, 0)
    plt.imshow(image.numpy())
    plt.title('Single Image')
    plt.axis('off')
    plt.show()

# Testing the dataset: build single-sample loaders and look at one batch
batch_size = 1
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# Create an iterator from the DataLoader and fetch the next batch
train_loader_iter = iter(train_loader)
batch = next(train_loader_iter)

# The batch wraps a single tensor (the dataset holds one tensor per sample)
print(batch[0].shape)

# Plot the first image of the first sequence in the batch
plot_image_tensor(batch[0][0][0])

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class ConvLSTMCell(nn.Module):
    """
    Convolutional LSTM cell.

    All four gate pre-activations (input, forget, cell, output) come from a
    single convolution applied to the channel-wise concatenation of the
    current input and the previous hidden state. ``kernel_size // 2`` padding
    is used, which keeps the spatial size unchanged for odd kernel sizes.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        # One conv yields 4 * hidden_channels maps: input, forget, cell, output
        self.gates = nn.Conv2d(
            input_channels + hidden_channels,
            4 * hidden_channels,
            kernel_size,
            padding=self.padding,
        )

    def forward(self, input_tensor, hidden_state):
        """Advance one time step and return the new ``(hidden, cell)`` pair."""
        prev_h, prev_c = hidden_state
        stacked = torch.cat((input_tensor, prev_h), dim=1)
        i_pre, f_pre, g_pre, o_pre = torch.chunk(self.gates(stacked), 4, dim=1)

        i_act = torch.sigmoid(i_pre)
        f_act = torch.sigmoid(f_pre)
        g_act = torch.tanh(g_pre)
        o_act = torch.sigmoid(o_pre)

        # Keep part of the old cell state, add the gated candidate
        c_next = f_act * prev_c + i_act * g_act
        h_next = o_act * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(hidden, cell)`` tensors on the same device as the gate weights."""
        height, width = image_size
        state_shape = (batch_size, self.hidden_channels, height, width)
        target = self.gates.weight.device
        return (torch.zeros(state_shape, device=target),
                torch.zeros(state_shape, device=target))

class ConvLSTM_VGG19(nn.Module):
    """
    ConvLSTM model with VGG19-based architecture

    Each frame of the input sequence is encoded by the convolutional part of
    a pre-trained VGG19, passed through three stacked ConvLSTM cells whose
    states persist across the sequence, and decoded by three transposed
    convolutions followed by a sigmoid.

    With a 7x7 feature map the decoder produces 7 -> 24 -> 74 -> 224 pixels
    per side, so the output matches the input resolution only for 224x224
    frames. Other resolutions run through the network but yield an output of
    a different spatial size.
    """
    def __init__(self):
        super(ConvLSTM_VGG19, self).__init__()

        # Load pre-trained VGG19 model
        vgg19_features = models.vgg19(pretrained=True).features

        # Using VGG19 as encoder (feature extractor)
        self.encoder = nn.Sequential(*list(vgg19_features.children()))

        # ConvLSTM layers
        self.convlstm1 = ConvLSTMCell(input_channels=512, hidden_channels=64, kernel_size=3)
        self.convlstm2 = ConvLSTMCell(input_channels=64, hidden_channels=32, kernel_size=3)
        self.convlstm3 = ConvLSTMCell(input_channels=32, hidden_channels=64, kernel_size=3)

        # Spatial Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=10, stride=3, padding=2),
            nn.BatchNorm2d(64),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=9, stride=3, padding=2),
            nn.BatchNorm2d(128),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=9,stride=3, padding=2)
        )

    def forward(self, x):
        """Reconstruct every frame of a batch of sequences.

        Args:
            x: 5-D tensor (batch, seq_len, channels, height, width).

        Returns:
            Tensor (batch, seq_len, 3, H_out, W_out) with values in (0, 1).
        """
        # Unpacking five values also rejects inputs that are not 5-D
        b, seq_len, _, _, _ = x.size()
        state1 = state2 = state3 = None

        frames_out = []
        for t in range(seq_len):
            xt = self.encoder(x[:, t])
            if state1 is None:
                # Size the recurrent states from the actual encoder output
                # instead of assuming the 7x7 map of a 224x224 input.
                feat_hw = tuple(xt.shape[-2:])
                state1 = self.convlstm1.init_hidden(b, feat_hw)
                state2 = self.convlstm2.init_hidden(b, feat_hw)
                state3 = self.convlstm3.init_hidden(b, feat_hw)
            # Each state is a (hidden, cell) pair; the hidden map feeds the next layer
            state1 = self.convlstm1(xt, state1)
            state2 = self.convlstm2(state1[0], state2)
            state3 = self.convlstm3(state2[0], state3)
            frames_out.append(torch.sigmoid(self.decoder(state3[0])))

        # Re-assemble the per-frame outputs along the time dimension
        return torch.stack(frames_out, dim=1)

# Example usage: shape smoke test on one random 10-frame sequence
model = ConvLSTM_VGG19()
image = torch.randn(1, 10, 3, 224, 224)
# Only the shape matters here, so skip autograd bookkeeping for the ten
# encoder/decoder passes.
with torch.no_grad():
    output = model(image)
print(output.size())
# The reconstruction loss compares output and input element-wise
assert output.shape == image.shape, "model output must match the input shape"


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

batch_size = 10
# Reshuffle the training sequences every epoch; validation order does not matter
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"train_loader.dataset: {len(train_loader.dataset)}")
print(f"val_loader.dataset: {len(val_loader.dataset)}")

# Both per-epoch averages divide by the number of batches
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_dataset and val_dataset must both be non-empty")

model = ConvLSTM_VGG19()
model = nn.DataParallel(model)
model.to(device)

# Reconstruction objective: the model is trained to reproduce its own input
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, eps=1e-6, weight_decay=1e-5)

# Early stopping parameters
patience = 15
best_val_loss = float('inf')
patience_counter = 0
best_model_path = 'VGG19_ConvLSTM.pth'

# To store loss statistics
train_losses = []
val_losses = []

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for (images,) in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch'):
        images = images.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, images)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    average_loss = total_loss / num_batches
    train_losses.append(average_loss)
    print(f'Epoch [{epoch + 1}/{num_epochs}] Average Training Loss: {average_loss:.4f}')
    
    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for (images,) in val_loader:
            images = images.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, images).item()

    val_loss /= len(val_loader)
    val_losses.append(val_loss)
    print(f'Validation Loss after Epoch {epoch + 1}: {val_loss:.4f}')

    # Early stopping check and model saving
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Saved from the DataParallel wrapper, so parameter keys carry the
        # "module." prefix; load into a wrapped model or strip the prefix.
        torch.save(model.state_dict(), best_model_path)
        print(f'Model saved at epoch {epoch + 1}')
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f'Early stopping triggered after epoch {epoch + 1}')
        break

# Plotting training and validation losses
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
