In [None]:
#!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
#!pip install opencv-python ipywidgets imageio

In [42]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from Seq2Seq import Seq2Seq
from torch.utils.data import DataLoader
import os
import cv2
import time
from tqdm import tqdm
import wandb

#os.environ["TORCH_DEVICE"] = "cuda"
# Sanity check: the cell output shows whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

# Data Preparation

In [43]:
from pathlib import Path
from datetime import datetime
import numpy as np
# Root folder of the SKIPPD image dump. Set the SKIPPD_ROOT environment variable to
# run the notebook on another machine; without it the original local path is used.
p = Path(os.environ.get('SKIPPD_ROOT', r'C:\Users\Omar\Documents\ShadowMitigation\SKIPPD'))
#device = os.environ.get("TORCH_DEVICE", 'cuda')

In [44]:

# Group every frame path by the hour it was captured in. The file stem is sliced as a
# YYYYMMDDhh... timestamp, so a key has the form "<year>-<month>-<day>-<hour>".
dataset = {}
for path in Path(p, 'SKIPPD').rglob('*.jpg'):
    stamp = path.stem
    year, month, day, hour = stamp[:4], stamp[4:6], stamp[6:8], stamp[8:10]
    dataset.setdefault(f'{year}-{month}-{day}-{hour}', []).append(path)

# Stop here with a clear message rather than letting the following cells work on
# an empty collection and fail later with an unrelated error.
if not dataset:
    raise FileNotFoundError(f"No .jpg frames found under {Path(p, 'SKIPPD')}")

# Keep each hour's frames in sorted path order
# (presumably chronological, given the timestamped file names - confirm).
for frames in dataset.values():
    frames.sort()
   

In [45]:
# Partition the hour keys by completeness: an hour is usable only when it holds
# exactly 60 frames; every other hour is set aside as faulty.
correct = [hour_key for hour_key, frames in dataset.items() if len(frames) == 60]
faulty = [hour_key for hour_key, frames in dataset.items() if len(frames) != 60]

In [46]:

# Load every complete hour as a list of grayscale frames.
SKIPPD = []
for valid in correct:
    hour_frames = []
    for frame_path in dataset[valid]:
        frame = cv2.imread(str(frame_path), cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports a missing or undecodable file by returning None, not by
        # raising; a None entry would corrupt the stacked array built from this list.
        if frame is None:
            raise IOError(f"Could not read image: {frame_path}")
        hour_frames.append(frame)
    SKIPPD.append(hour_frames)
print(np.array(SKIPPD).shape)

(0,)


In [None]:
# SKIPPD is a list of hours (each a list of frames). Convert it to an array once and
# reuse that array instead of re-converting the whole list for every operation.
skippd_array = np.array(SKIPPD)
if skippd_array.ndim != 4:
    # The sum below reduces over axes 1-3, so anything but a 4-D stack cannot work
    raise ValueError(
        f"Expected a 4-D stack of sequences, got shape {skippd_array.shape}; "
        "check that images were loaded and that they all share one size"
    )

# Minimum total pixel intensity a sequence must exceed to be kept
# (the equivalent of 20000 pixels at the maximum value 255).
threshold = 20000*255

# Total pixel intensity of each sequence
sequence_sums = skippd_array.sum(axis=(1, 2, 3))

# Keep only the sequences whose total intensity is above the threshold
filtered_indices = np.where(sequence_sums > threshold)[0]
filtered_SKIPPD = skippd_array[filtered_indices]

print(f"Original shape: {skippd_array.shape}")
print(f"Filtered shape: {filtered_SKIPPD.shape}")

In [None]:
from torch.utils.data import DataLoader


# Shuffle with a fixed seed so the train/val/test membership is identical on every
# run; otherwise a checkpoint resumed in a new session is validated on sequences it
# may already have been trained on.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)

# Indexing with a permutation yields a shuffled copy. The previous in-place shuffle
# of a slice also reordered filtered_SKIPPD, because a basic slice is a view.
SKIPPD_SHORT = filtered_SKIPPD[split_rng.permutation(len(filtered_SKIPPD))]

# Train, Validation, Test splits: 70% / 20% / 10%
n_sequences = len(SKIPPD_SHORT)
train_end = int(0.7*n_sequences)
val_end = int(0.9*n_sequences)

train_data = SKIPPD_SHORT[:train_end]
val_data = SKIPPD_SHORT[train_end:val_end]
test_data = SKIPPD_SHORT[val_end:]

def collate(batch, device='cuda', threshold=1000, max_attempts=15):
    """Build one random (input window, target frame) pair from a batch of sequences.

    Args:
        batch: list of equally shaped arrays, one per sample, indexed as
            (frame, height, width). Each needs at least 60 frames, because the
            target index can be as high as 59.
        device: device the scaled batch is moved to.
        threshold: a window is accepted once the summed intensity of the whole
            input window (after scaling to [0, 1]) exceeds this value.
        max_attempts: number of windows to try before giving up and returning the
            last one drawn. Values below 1 are treated as 1.

    Returns:
        Tuple ``(input_seq, target)``: ``input_seq`` has shape (B, 1, 15, H, W) and
        ``target`` has shape (B, 1, H, W), both as floats in [0, 1].
    """
    # Stack on the NumPy side first: torch.tensor() on a list of arrays is very slow.
    # Then add the channel dimension and scale pixel values to [0, 1].
    frames = torch.as_tensor(np.asarray(batch)).unsqueeze(1) / 255.0
    frames = frames.to(device)

    # Always draw at least one window so there is something to return
    for _ in range(max(1, max_attempts)):
        # rand is in [15, 58]: the window is the 15 frames before index rand
        rand = np.random.randint(15, 59)
        input_seq = frames[:, :, rand-15:rand]
        # NOTE(review): the target is frame rand+1, so frame `rand` itself is skipped
        # between the window and the target - confirm this gap is intended.
        target = frames[:, :, rand+1]

        # Accept the first window that is bright enough
        if input_seq.sum() > threshold:
            break

    # Either an accepted window or, after max_attempts misses, the last one drawn
    return input_seq, target

# The data loaders below use `collate` to assemble every batch.


# collate() moves every batch to the GPU, so batches are assembled in the main
# process (num_workers=0): PyTorch advises against returning CUDA tensors from
# DataLoader worker processes.

# Training Data Loader
train_loader = DataLoader(train_data, shuffle=True,
                        batch_size=16, collate_fn=collate, num_workers=0)

# Validation Data Loader (shuffled so the preview cells show varied samples)
val_loader = DataLoader(val_data, shuffle=True,
                        batch_size=16, collate_fn=collate, num_workers=0)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Print the shape of every frame belonging to an incomplete hour.
for hour in faulty:
    images =[cv2.imread(str(i), cv2.IMREAD_GRAYSCALE) for i in dataset[hour]]

    for frame_path, image in zip(dataset[hour], images):
        # cv2.imread returns None for a file it cannot decode; report it instead of
        # crashing on `.shape`
        if image is None:
            print(f"{frame_path}: could not be read")
            continue
        print(image.shape)
    # fig, axs = plt.subplots(8, 8, figsize=(12, 12))

    # for i, ax in enumerate(axs.flat):
    #     if i < len(images):
    #         ax.imshow(images[i], cmap='gray')
    #     ax.axis('off')

    # plt.tight_layout()
    # plt.show()

In [None]:
import io
import imageio
from ipywidgets import widgets, HBox

# Pull one validation batch; only the input windows are shown here
input, _ = next(iter(val_loader))

# Map the [0, 1] tensors back to the 0-255 pixel range for display
input = input.cpu().numpy() * 255.0

# Drop the channel axis and render at most 32 clips, each as an in-memory GIF
clips = input.squeeze(1)[:32]
for clip in clips:
    gif_buffer = io.BytesIO()
    with gif_buffer:
        imageio.mimsave(gif_buffer, clip.astype(np.uint8), "GIF", fps=5)
        display(HBox([widgets.Image(value=gif_buffer.getvalue())]))

In [None]:
import matplotlib.pyplot as plt

# Plot the first sample of one validation batch: its input frames plus the target
for input_seq, target in val_loader:
    # Undo the [0, 1] scaling so the frames are back in the 0-255 range
    input_seq = input_seq.cpu().numpy() * 255.0
    target = target.cpu().numpy() * 255.0

    print("Input sequence shape:", input_seq.shape)
    print("Target shape:", target.shape)

    # One (grid position, image, title) entry per panel: the input frames fill the
    # 4x4 grid from slot 1 onward and the target frame takes the last slot (16).
    panels = [
        (idx + 1, input_seq[0, 0, idx], f'Input Frame {idx+1}')
        for idx in range(input_seq.shape[2])
    ]
    panels.append((16, target[0,0], 'Target Frame'))

    plt.figure(figsize=(20, 8))
    for position, frame, title in panels:
        plt.subplot(4, 4, position)
        plt.imshow(frame, cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()
    break  # the first batch is enough for a visual check


# Model & Training

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

device = 'cuda'

# Network hyperparameters; the frames are grayscale, hence a single input channel
seq2seq_kwargs = dict(
    num_channels=1,
    num_kernels=64,
    kernel_size=(3, 3),
    padding=(1, 1),
    activation="relu",
    frame_size=(64, 64),
    num_layers=3,
)
model = Seq2Seq(**seq2seq_kwargs).to(device)

lr=1e-4
optim = Adam(model.parameters(), lr=lr)

# Summed binary cross-entropy; the targets are frames scaled to [0, 1]
criterion = nn.BCELoss(reduction='sum')

# Best validation loss seen so far, used to decide when to write a checkpoint
min_loss = float('inf')

# Cut the learning rate by 10x after 5 epochs without validation improvement
scheduler = ReduceLROnPlateau(optim, mode='min', factor=0.1, patience=5)


In [None]:
wandb.init(
    # set the wandb project where this run will be logged
    project="ShadowMitigation",

    # track hyperparameters and run metadata
    config={
    # Log the rate the optimizer was actually built with, not a hand-typed copy
    # (the previous literal 1e-7 did not match it).
    "learning_rate": lr,
    "architecture": "CONVLSTM - 3 Layers",
    "dataset": "0.3 SKIPPD",
    # NOTE(review): the training cell runs 50 epochs - confirm which is intended
    "epochs": 100,
    }
)

In [None]:
# Resume from a previously saved checkpoint. map_location places the tensors on the
# training device regardless of the device they were saved from.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
asd = torch.load(r'C:\Users\Omar\Documents\ShadowMitigation\playground\pyt\best.pth',
                 map_location=device)
model.load_state_dict(asd['model_state_dict'])

In [None]:
num_epochs = 50

# Train for num_epochs, validating after each epoch and saving a checkpoint
# whenever the validation loss improves on the best value seen so far.
for epoch in range(1, num_epochs+1):
    start = time.time()

    # ---- training pass ----
    train_loss = 0
    model.train()
    for input_seq, target in tqdm(train_loader):
        # Clear gradients first so nothing left over from an interrupted run leaks in
        optim.zero_grad()
        output = model(input_seq)
        loss = criterion(output.flatten(), target.flatten())
        loss.backward()
        optim.step()

        # Accumulate the summed batch loss; this was commented out, which made the
        # reported training loss always 0
        train_loss += loss.item()

    # The criterion sums over each batch, so dividing by the dataset size gives the
    # mean loss per sample
    train_loss /= len(train_loader.dataset)

    # ---- validation pass ----
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for val_input, target in val_loader:
            output = model(val_input)
            loss = criterion(output.flatten(), target.flatten())
            val_loss += loss.item()
    val_loss /= len(val_loader.dataset)

    # Let the scheduler lower the learning rate when validation stops improving
    scheduler.step(val_loss)

    # Keep only the best model seen so far
    if val_loss < min_loss:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            }
        torch.save(checkpoint, 'best_240220_2.pth')
        min_loss = val_loss

    # val_loss/(64*64)*100 expresses the per-sample loss relative to the 64x64 pixel count
    wandb.log({"loss": val_loss,"val_loss_perc":val_loss/(64*64) *100,
               "train_loss": train_loss})
    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n  Validation Loss percentage: {:.2f}    Time: {:.2f}".format(
       epoch, train_loss, val_loss, val_loss/(64*64) *100,time.time()-start))


In [15]:
# Lists the parameter names stored under the checkpoint's 'state_dict' entry.
# NOTE(review): this needs `checkpoint` to be the dict loaded from the .ckpt file in
# the SimVP cell further down; the dict the training loop stores under this name has
# no 'state_dict' key, so the cell fails when the notebook is run top to bottom.
checkpoint['state_dict'].keys()

odict_keys(['model.enc.enc.0.conv.conv.weight', 'model.enc.enc.0.conv.conv.bias', 'model.enc.enc.0.conv.norm.weight', 'model.enc.enc.0.conv.norm.bias', 'model.enc.enc.1.conv.conv.weight', 'model.enc.enc.1.conv.conv.bias', 'model.enc.enc.1.conv.norm.weight', 'model.enc.enc.1.conv.norm.bias', 'model.enc.enc.2.conv.conv.weight', 'model.enc.enc.2.conv.conv.bias', 'model.enc.enc.2.conv.norm.weight', 'model.enc.enc.2.conv.norm.bias', 'model.enc.enc.3.conv.conv.weight', 'model.enc.enc.3.conv.conv.bias', 'model.enc.enc.3.conv.norm.weight', 'model.enc.enc.3.conv.norm.bias', 'model.dec.dec.0.conv.conv.0.weight', 'model.dec.dec.0.conv.conv.0.bias', 'model.dec.dec.0.conv.norm.weight', 'model.dec.dec.0.conv.norm.bias', 'model.dec.dec.1.conv.conv.weight', 'model.dec.dec.1.conv.conv.bias', 'model.dec.dec.1.conv.norm.weight', 'model.dec.dec.1.conv.norm.bias', 'model.dec.dec.2.conv.conv.0.weight', 'model.dec.dec.2.conv.conv.0.bias', 'model.dec.dec.2.conv.norm.weight', 'model.dec.dec.2.conv.norm.bias', 

In [37]:
batch_size = 10

# Training-side overrides: sequence lengths, batch sizes, schedule, metrics and the
# experiment/dataset identifiers.
custom_training_config = dict(
    pre_seq_length=15,
    aft_seq_length=15,
    total_length=15 + 15,
    batch_size=batch_size,
    val_batch_size=batch_size,
    epoch=20,
    lr=0.001,
    metrics=['mse', 'mae'],
    ex_name='custom_exp',
    dataname='custom',
    in_shape=[15, 3, 64, 64],
)

# Model-side overrides. For MetaVP models the most important hyperparameters are
# N_S, N_T, hid_S, hid_T and model_type. They are set directly here; a config file
# ('config_file': 'configs/custom/example_model.py') is the alternative.
custom_model_config = dict(
    method='SimVP',
    model_type='gSTA',
    N_S=4,
    N_T=8,
    hid_S=64,
    hid_T=256,
)

In [38]:
from openstl.utils import create_parser, default_parser

# Start from the parser's values for an empty command line; `config` is the parsed
# namespace's own attribute dict, so edits to it are visible through `args` too.
args = create_parser().parse_args([])
config = vars(args)

# Fill every option the parser left as None with its library default
default_values = default_parser()
for attribute, fallback in default_values.items():
    if config[attribute] is None:
        config[attribute] = fallback

# Apply the notebook's overrides: training settings first, then model settings
for overrides in (custom_training_config, custom_model_config):
    config.update(overrides)


In [41]:
from openstl.methods import SimVP
from openstl.models import SimVP_Model

PATH = r'C:\Users\Omar\Documents\ShadowMitigation\playground\content\checkpoints\best.ckpt'

# NOTE: this rebinds `model`, replacing the Seq2Seq network built in the training section
model = SimVP(steps_per_epoch=1,
            test_mean=1,
            test_std=1,
            save_dir="",
            **config)

# Load the checkpoint. map_location makes the load independent of the device the
# file was saved from; load_state_dict then copies the weights into the model's own
# parameters. torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(PATH, map_location='cpu')

# Load the stored weights into the model (kept as the last expression so the
# key-matching report is shown as the cell output)
model.load_state_dict(checkpoint["state_dict"])


<All keys matched successfully>

In [None]:
def collate_test(batch):
    """Assemble a test batch and split off the frames used as ground truth.

    Args:
        batch: list of equally shaped arrays, one per sample, indexed as
            (frame, height, width).

    Returns:
        Tuple ``(batch, target)``: ``batch`` is a float tensor of shape
        (B, 1, frames, H, W) scaled to [0, 1] on the module-level ``device``;
        ``target`` is a NumPy array with the unscaled frames from index 45 onward
        (the last 15 frames of a 60-frame sequence).
    """
    # Stack once and reuse: the previous code converted the list twice, one of them
    # through torch.tensor() on a list of arrays, which is very slow.
    frames = np.array(batch)

    # Frames from index 45 onward are the target
    target = frames[:, 45:]

    # Add channel dim, scale pixels between 0 and 1, send to the device
    batch = torch.from_numpy(frames).unsqueeze(1) / 255.0
    batch = batch.to(device)
    return batch, target

# Test Data Loader

test_loader = DataLoader(test_data,shuffle=True,
                         batch_size=20, collate_fn=collate_test)

# Get a batch
batch, target = next(iter(test_loader))

# Initialize output sequence
output = np.zeros(target.shape, dtype=np.uint8)

# Inference only: put the model in evaluation mode and skip autograd bookkeeping
model.eval()
with torch.no_grad():
    # Loop over timesteps
    for timestep in tqdm(range(target.shape[1])):
        # NOTE(review): the window covers frames [timestep, timestep+15) while the
        # prediction is stored against target frame 45+timestep - confirm the window
        # should not start 30 frames later.
        window = batch[:,:,timestep:timestep+15]
        # Binarize the prediction at 0.5 and rescale to 0/255
        output[:,timestep]=(model(window).squeeze(1).cpu()>0.5)*255.0

In [None]:
from random import shuffle
# random.shuffle works in place and returns None, so the pairs must live in a named
# list; shuffling a temporary list left the display order unchanged.
pairs = list(zip(target, output))
shuffle(pairs)

for tgt, out in pairs:       # Loop over samples

    # Write target video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, tgt, "GIF", fps = 2)
        target_gif = gif.getvalue()

    # Write output video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, out, "GIF", fps = 2)
        output_gif = gif.getvalue()

    # Ground truth on the left, model output on the right
    display(HBox([widgets.Image(value=target_gif),
                  widgets.Image(value=output_gif)]))