DL Project Prototype

In [None]:
# Colab-only setup: mount Google Drive inside the VM at /content/drive so the
# dataset archive stored on Drive is reachable through the filesystem.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
# Work from the Drive root so that relative paths used by later cells
# (e.g. the archive name passed to unzip) resolve inside Drive.
os.chdir('/content/drive/My Drive')

In [None]:
!unzip Dataset_Student_V2

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: Dataset_Student/unlabeled/video_14093/image_19.png  
  inflating: Dataset_Student/unlabeled/video_14093/image_20.png  
  inflating: Dataset_Student/unlabeled/video_14093/image_21.png  
   creating: Dataset_Student/unlabeled/video_14099/
  inflating: Dataset_Student/unlabeled/video_14099/image_0.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_1.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_2.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_3.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_4.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_5.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_6.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_7.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_8.png  
  inflating: Dataset_Student/unlabeled/video_14099/image_9.png  
  inflating: Dataset_Student/unlabe

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

In [None]:
class VideoDataset(Dataset):
    """Dataset of video clips stored as one folder of PNG frames per video.

    Expected layout: ``<base_path>/<dataset_type>/<video_folder>/image_<i>.png``
    plus, for every split except ``'unlabeled'``, a ``mask.npy`` file inside
    each video folder.

    Args:
        base_path: Root directory that contains the split folders.
        dataset_type: Name of the split folder (``'train'``, ``'val'`` or
            ``'unlabeled'``). Only ``'unlabeled'`` is treated as mask-free.
        transform: Optional callable applied to every PIL frame. When omitted,
            frames are converted to float tensors laid out ``(C, H, W)`` with
            values scaled to ``[0, 1]``.

    Each item is ``(images, mask)``: ``images`` is the stack of all frames of
    one video, ``mask`` is the content of ``mask.npy`` as a ``long`` tensor, or
    an empty tensor for the unlabeled split.
    """

    # Every video folder holds frames image_0.png ... image_21.png. This also
    # holds for the unlabeled split (its archive contains image_21.png), so the
    # same count is used everywhere; the previous special case of 21 frames
    # silently dropped the last unlabeled frame.
    FRAMES_PER_VIDEO = 22

    def __init__(self, base_path, dataset_type='train', transform=None):
        self.base_path = base_path
        self.dataset_type = dataset_type
        self.transform = transform
        self.samples = self._load_samples()

    def _load_samples(self):
        """Return a list of ``(frame_paths, mask_path_or_None)``, one per video."""
        split_dir = os.path.join(self.base_path, self.dataset_type)
        has_masks = self.dataset_type != 'unlabeled'

        samples = []
        # os.listdir order is arbitrary; sort it so indices are reproducible.
        for video_folder in sorted(os.listdir(split_dir)):
            video_path = os.path.join(split_dir, video_folder)
            if not os.path.isdir(video_path):
                # Skip stray files (notes, OS metadata, ...) next to the videos.
                continue
            images = [os.path.join(video_path, f'image_{i}.png')
                      for i in range(self.FRAMES_PER_VIDEO)]
            mask_path = os.path.join(video_path, 'mask.npy') if has_masks else None
            samples.append((images, mask_path))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_paths, mask_path = self.samples[idx]

        frames = []
        for path in image_paths:
            # The context manager closes the file handle; convert() forces the
            # pixel data to be read first and guarantees three channels.
            with Image.open(path) as img:
                frames.append(img.convert('RGB'))

        if self.transform is not None:
            frames = [self.transform(frame) for frame in frames]
        else:
            # torch.stack cannot stack PIL images, so fall back to a plain
            # (C, H, W) float conversion when no transform was supplied.
            frames = [torch.from_numpy(np.array(frame)).permute(2, 0, 1).float().div(255)
                      for frame in frames]

        images = torch.stack(frames)

        if mask_path:
            mask = torch.tensor(np.load(mask_path), dtype=torch.long)
        else:
            # Placeholder for the unlabeled split.
            mask = torch.tensor([])

        return images, mask

In [None]:
# Root folder of the extracted dataset. VideoDataset joins the split name
# passed as dataset_type ('train', 'val', 'unlabeled') onto this path.
base_path = '/content/drive/My Drive/Dataset_Student'  # Colab/Drive-specific; adjust when running elsewhere

In [None]:
# Per-frame preprocessing handed to VideoDataset, which applies it to every
# image of a video before stacking the frames.
transform = transforms.Compose([
    # For 8-bit PIL images, ToTensor yields a float tensor laid out
    # channels-first (C, H, W) with values scaled to [0, 1].
    transforms.ToTensor(),
    # Add any other transformations here
])

In [None]:
# One dataset per split folder under base_path; all share the same transform.
train_dataset = VideoDataset(base_path, dataset_type='train', transform=transform)
val_dataset = VideoDataset(base_path, dataset_type='val', transform=transform)
unlabeled_dataset = VideoDataset(base_path, dataset_type='unlabeled', transform=transform)

# Create DataLoaders for the labelled splits.
# Training data is reshuffled every epoch; validation keeps a fixed order so
# that evaluation runs are reproducible and comparable across epochs.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
import torch.nn as nn
import torchvision.models as models

In [None]:
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    The input and the previous hidden state are concatenated along the channel
    axis and passed through one convolution that produces the four gate
    pre-activations (input, forget, output, candidate) at once.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        """
        Initialize ConvLSTM cell.

        Args:
            input_dim: Number of channels of the input tensor.
            hidden_dim: Number of channels of the hidden and cell states.
            kernel_size: Size of the gate convolution kernel, either an int or
                a ``(height, width)`` pair. Every entry must be odd so that
                the padding keeps the spatial size unchanged.
            bias: Whether the gate convolution has a bias term.

        Raises:
            ValueError: If ``kernel_size`` is not an int / pair of ints, or
                contains a non-positive or even entry.
        """
        super(ConvLSTMCell, self).__init__()

        if isinstance(kernel_size, int):
            kernel_dims = (kernel_size, kernel_size)
        else:
            kernel_dims = tuple(kernel_size)
        if len(kernel_dims) != 2 or any(k < 1 or k % 2 == 0 for k in kernel_dims):
            # With an even kernel the convolution output would be one pixel
            # larger than the state, and the gate arithmetic in forward() fails.
            raise ValueError(
                f'kernel_size must be a positive odd int or a pair of them, got {kernel_size!r}')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # Half-kernel padding preserves height and width.
        if isinstance(kernel_size, int):
            self.padding = kernel_size // 2
        else:
            self.padding = tuple(k // 2 for k in kernel_dims)
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Args:
            input_tensor: Tensor of shape ``(batch, input_dim, H, W)``.
            cur_state: Pair ``(h, c)``, each ``(batch, hidden_dim, H, W)``.

        Returns:
            The next ``(h, c)`` pair, with the same shapes as ``cur_state``.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        # Channel blocks, in order: input gate, forget gate, output gate, candidate.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-filled ``(h, c)`` states for inputs of ``image_size``.

        The states are created on the same device and with the same dtype as
        the convolution weights, so the cell keeps working after
        ``.to(device)``, ``.double()`` or ``.half()``.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

In [None]:
class ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells unrolled over a sequence of feature maps.

    Layer 0 consumes the input sequence; every following layer consumes the
    hidden-state sequence produced by the layer below it.

    Args:
        input_dim: Number of channels of the input frames.
        hidden_dim: Number of channels of every layer's hidden state.
        kernel_size: Kernel size passed to every ConvLSTMCell.
        num_layers: Number of stacked cells (at least 1).
        batch_first: Stored for reference only. ``forward`` always expects the
            batch on dim 0, which is how this module has always been called;
            the flag does not change the layout.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, batch_first=False):
        super(ConvLSTM, self).__init__()

        if num_layers < 1:
            raise ValueError(f'num_layers must be at least 1, got {num_layers}')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first

        self.layers = nn.ModuleList()

        for layer_idx in range(self.num_layers):
            # Only the first layer sees the raw input; deeper layers receive
            # the hidden state of the previous layer (hidden_dim channels).
            layer_input_dim = self.input_dim if layer_idx == 0 else self.hidden_dim
            self.layers.append(ConvLSTMCell(input_dim=layer_input_dim,
                                            hidden_dim=self.hidden_dim,
                                            kernel_size=self.kernel_size))

    def forward(self, x, hidden_state=None):
        """Run the stack over a whole sequence.

        Args:
            x: Tensor of shape ``(batch, sequence, channels, height, width)``.
            hidden_state: Optional list with one ``(h, c)`` pair per layer.
                Zero states are created when omitted.

        Returns:
            ``(outputs, (h, c))`` where ``outputs`` holds the last layer's
            hidden state at every time step, shape
            ``(batch, sequence, hidden_dim, height, width)``, and ``(h, c)``
            is the last layer's state after the final step.
        """
        batch_size, seq_len, _, height, width = x.size()

        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=batch_size, image_size=(height, width))

        layer_input = x
        h_t = c_t = None
        for layer_idx, cell in enumerate(self.layers):
            h_t, c_t = hidden_state[layer_idx]
            step_outputs = []
            for t in range(seq_len):
                h_t, c_t = cell(layer_input[:, t, :, :, :], (h_t, c_t))
                step_outputs.append(h_t)

            # This layer's output sequence is the next layer's input sequence.
            layer_input = torch.stack(step_outputs, dim=1)

        # Only the top layer's output sequence and final state are returned.
        return layer_input, (h_t, c_t)

    def _init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states for every layer."""
        return [cell.init_hidden(batch_size, image_size) for cell in self.layers]

In [None]:
from torchvision.models.segmentation import fcn_resnet50

# Instantiate the segmentation network. Despite the variable name, this is
# torchvision's FCN with a ResNet-50 backbone, not a U-Net.
# NOTE(review): no num_classes is passed, so the library default is used --
# confirm it covers every label id stored in the mask.npy files.
# NOTE(review): the recorded output shows resnet50 weights being downloaded
# even with pretrained=False, so the backbone presumably starts from
# pretrained weights -- confirm whether training from scratch was intended.
unet_model = fcn_resnet50(pretrained=False)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 74.2MB/s]


In [None]:
class FramePredictionAndSegmentation(nn.Module):
    """Predict a future frame with a ConvLSTM and segment the predicted frame.

    Args:
        hidden_dim: Number of channels of the ConvLSTM hidden state.
        frame_channels: Number of channels of the input and predicted frames.
    """

    def __init__(self, hidden_dim=64, frame_channels=3):
        super(FramePredictionAndSegmentation, self).__init__()
        self.conv_lstm = ConvLSTM(input_dim=frame_channels, hidden_dim=hidden_dim,
                                  kernel_size=3, num_layers=2)
        # The ConvLSTM emits hidden_dim feature channels. A 1x1 convolution
        # maps them back to image channels so the prediction can be compared
        # with a real frame and fed to the segmentation network, both of
        # which work on frame_channels-channel images.
        self.frame_head = nn.Conv2d(hidden_dim, frame_channels, kernel_size=1)
        self.segmentation_net = unet_model

    def forward(self, x):
        """Return ``(predicted_frame, segmentation_mask)`` for a batch of clips.

        Args:
            x: Batch of frame sequences (11 frames each), shape
                ``(batch, sequence, channels, height, width)``.
        """
        # Roll the ConvLSTM over the sequence and keep the state after the
        # last input frame as the summary used to predict the 22nd frame.
        hidden_sequence, _ = self.conv_lstm(x)
        last_hidden = hidden_sequence[:, -1, :, :, :]

        # Sigmoid keeps the prediction in [0, 1], the range ToTensor gives
        # the ground-truth frames.
        predicted_frame = torch.sigmoid(self.frame_head(last_hidden))

        # Perform segmentation on the predicted frame
        segmentation_mask = self.segmentation_net(predicted_frame)['out']

        return predicted_frame, segmentation_mask

# Instantiate the model
model = FramePredictionAndSegmentation()


In [None]:
# Loss functions: pixel-wise MSE for the predicted frame, cross-entropy for
# the per-pixel class scores of the segmentation output.
frame_prediction_criterion = nn.MSELoss()
segmentation_criterion = nn.CrossEntropyLoss()

# Train on the GPU when one is available. The model is moved before the
# optimizer is created so the optimizer tracks the parameters on that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 1  # Number of training epochs
num_input_frames = 11  # Frames shown to the model; the last frame of the clip is the target

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    for i, (sequences, masks) in enumerate(train_loader):
        # The dataset's ToTensor transform already yields channels-first
        # frames, so batches arrive as (batch, sequence, channels, height,
        # width) and must not be permuted.
        sequences = sequences.to(device)
        masks = masks.to(device)

        # Feed only the leading frames; passing the whole clip would hand the
        # target frame to the model as input.
        input_frames = sequences[:, :num_input_frames]

        # The ground truth 22nd frame is the last frame in the sequences
        true_22nd_frame = sequences[:, -1, :, :, :]

        # NOTE(review): mask.npy presumably stores one mask per frame, i.e.
        # (sequence, height, width) -- confirm. In that case the target is the
        # mask of the last frame; an already per-clip mask is used as-is.
        if masks.dim() == 4:
            masks = masks[:, -1]

        # Forward pass
        predicted_frame, predicted_mask = model(input_frames)

        frame_loss = frame_prediction_criterion(predicted_frame, true_22nd_frame)
        segmentation_loss = segmentation_criterion(predicted_mask, masks)

        # Compute total loss
        total_loss = frame_loss + segmentation_loss

        # Backward pass and optimization
        optimizer.zero_grad()  # Zero the gradient buffers
        total_loss.backward()  # Backpropagation
        optimizer.step()       # Update weights

        # Optional: Print out loss, accuracy, etc.
        if (i + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Frame Loss: {frame_loss.item()}, Segmentation Loss: {segmentation_loss.item()}')

    # Optional: Perform validation after each epoch
    # ...

# Save the trained model
# torch.save(model.state_dict(), 'model.pth')


RuntimeError: ignored