In [1]:
import os
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from PIL import Image
import torch
import pdb
import numpy as np
import yaml
from tqdm import tqdm

In [2]:
# Define the subjects for each split
train_subjects = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 3, 5, 6, 7, 10]
val_subjects = [24, 25, 1, 4]
test_subjects = [22, 2, 8, 9]

# Define the background variations
background_variations = ['d1', 'd2', 'd3', 'd4']

processed_folder = './processed/'

In [3]:
#load config file and hyperparams
config = yaml.safe_load(open("config.yaml"))
LR = float(config["LR"])
batch_size = int(config["BATCH_SIZE"])
num_epochs = int(config["NUM_EPOCHS"])


In [4]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    # Add more augmentations if needed
])

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [6]:
class SequencesExtractor:
    def __init__(self, path, num_frames_per_subsequence=20):
        # Define the number of frames per subsequence
        self.num_frames_per_subsequence = num_frames_per_subsequence
        # List all action folders in the processed folder
        self.classes = os.listdir(path) # folders correspond to classes/labels
        self.class_to_label = {class_name: idx for idx, class_name in enumerate(self.classes)}
    
    def get_classes(self):
        return self.classes
    def create_sequences(self, subjects, background_variations):
        sequences = []
        target_arr = []
        # Iterate over each action folder
        for action_folder in self.classes:
            target = action_folder
            action_path = os.path.join(processed_folder, action_folder)

            # List all person folders in the action folder
            person_folders_actual = os.listdir(action_path)
            # Filter videos based on subjects and background variations
            person_folders_target = [f'person{subject:02d}_{action_folder}_{bg}' 
                           for subject in subjects 
                           for bg in background_variations]
            person_folders = set(person_folders_actual) & set(person_folders_target) # it can be that some background variations (or smth else) is missing. 
            # Iterate over each person folder
            for person_folder in person_folders:
                person_path = os.path.join(action_path, person_folder)
                # List all image files in the person folder
                image_files = sorted(os.listdir(person_path))

                # Split the image files into subsequences
                num_frames = len(image_files)
                num_subsequences = num_frames // self.num_frames_per_subsequence

                for i in range(num_subsequences):
                    start_index = i * self.num_frames_per_subsequence
                    end_index = start_index + self.num_frames_per_subsequence

                    # Load and process the frames in the subsequence
                    subsequence_frames = []
                    for j in range(start_index, end_index):
                        try:
                            image_path = os.path.join(person_path, image_files[j])
                            frame = Image.open(image_path).convert('RGB')
                            # Apply any desired spatial augmentations to the frame
                            frame = transform(frame)
                            subsequence_frames.append(frame)
                        except:
                            print("Tried to read wrong file. Continuing")
                            continue

                  #  subsequence_frames = torch.stack(subsequence_frames, dim=0)
                    # Perform further processing on the subsequence
                    # Check if subsequence has the expected number of frames
                    if len(subsequence_frames) == self.num_frames_per_subsequence:
                        # Process the subsequence (e.g., feed it to a model for action classification)
                        subsequence_frames = torch.stack(subsequence_frames, dim=0)
                        # Perform further processing on the subsequence

                        target_arr.append(self.class_to_label[target])
                        sequences.append(subsequence_frames)
                    else:
                        print(f"Skipping subsequence due to incorrect number of frames: {len(subsequence_frames)}")
                        # target_arr.append(self.class_to_label[target])
                        # sequences.append(subsequence_frames)
        return np.array(sequences), np.array(target_arr)

In [7]:
sequencesExtractor = SequencesExtractor(path=processed_folder, num_frames_per_subsequence=20)
train_sequences, train_target_arr = sequencesExtractor.create_sequences(train_subjects, background_variations)
test_sequences, test_target_arr = sequencesExtractor.create_sequences(test_subjects, background_variations)
val_sequences, val_target_arr = sequencesExtractor.create_sequences(val_subjects, background_variations)

Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19
Tried to read wrong file. Continuing
Skipping subsequence due to incorrect number of frames: 19


In [8]:
print(f"Training sequences length: {len(train_sequences)}")
print(f"Validation sequences length: {len(val_sequences)}")
print(f"Test sequences length: {len(test_sequences)}")

Training sequences length: 9841
Validation sequences length: 2140
Test sequences length: 2252


In [9]:
from torch.utils.data import Dataset, DataLoader

In [10]:
class KTHDataset(Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]

        return sequence, label

In [11]:
train_dataset = KTHDataset(train_sequences, train_target_arr)
test_dataset = KTHDataset(test_sequences, test_target_arr)
val_dataset = KTHDataset(val_sequences, val_target_arr)

In [12]:
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [13]:
if(not os.path.exists("models")):
    os.makedirs("models")

In [14]:
class ConvBlock(nn.Module):
    """
    Encapuslation of a convolutional block (conv + activation + pooling)
    """
    def __init__(self, in_ch, out_ch, k_size, pool=False, dropout_prob = 0.2, mxpool_stride=1):
        super(ConvBlock, self).__init__()
        layers = []
        layers.append(nn.Conv2d(in_ch, out_ch, k_size))
        layers.append(nn.ReLU())
        if(pool):
            layers.append(nn.MaxPool2d(kernel_size=2))
       # layers.append(nn.Dropout(p=dropout_prob))
        self.module = nn.Sequential(*layers)
    def forward(self, x):
        return(self.module(x))


In [15]:
class ConvRecurrentClassifier(nn.Module):
    def __init__(self, input_channels, hidden_size, num_classes,num_layers = 20, mode="zeros"):
        assert mode in ["zeros", "random"]
        super(ConvRecurrentClassifier, self).__init__()
        self.mode = mode
        self.num_layers = num_layers
        self.hidden_dim = hidden_size
        #Convolutional Encoder
        # self.conv_encoder = nn.Sequential(
        #     ConvBlock(3, 16, 3, pool=False),
        #     ConvBlock(16, 32, 3, pool=True),
        #     ConvBlock(32, 64, 3, pool=False),
        #     ConvBlock(64, 128, 3, pool=True)
        # )
        
        self.conv_encoder = nn.Sequential(
            # nn.Conv2d(input_channels, 64, kernel_size=3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.MaxPool2d(kernel_size=2, stride=2),
            # nn.Conv2d(64, 128, kernel_size=3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.MaxPool2d(kernel_size=2, stride=2),
            # Add more convolutional layers if needed
            ConvBlock(3, 16, 3, pool=False),
            ConvBlock(16, 32, 3, pool=True),
            ConvBlock(32, 64, 3, pool=False),
            ConvBlock(64, 128, 3, pool=True),
            ConvBlock(128, 256, 3, pool=False),
            ConvBlock(256, 512, 3, pool=True),
            ConvBlock(512, 1024, 3, pool=True),
        )
        
        #Recurrent Module
        #self.lstm = nn.LSTM(input_size=256 * 315 * 11, hidden_size=hidden_size, batch_first=True)
        # nn.LSTM
        # LSTM model       
        lstms = []
        for i in range(num_layers):
            in_size = 78848 if i == 0 else self.hidden_dim
            lstms.append( nn.LSTMCell(input_size=in_size, hidden_size=self.hidden_dim) )
        self.lstm = nn.ModuleList(lstms)
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, num_classes)
        )
        
    def forward(self, x):
        b_size, seq_length, n_channels, width, height = x.shape
        # Reshape input to (batch_size, channels, sequence_length*height, width)
        x = x.view(b_size, n_channels, seq_length*width, height)
        
        h, c = self.init_state(b_size=b_size, device=device)
         # Convolutional Encoder
        x = self.conv_encoder(x)
        
        # Reshape for LSTM
        embeddings = x.view(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
        # Recurrent Module
        #out, (h_out, c_out) = self.lstm(x, (h,c))
        # iterating over sequence length
        lstm_out = []
        for i in range(embeddings.shape[1]):
            lstm_input = embeddings[:, i, :]
            # iterating over LSTM Cells
            for j, lstm_cell in enumerate(self.lstm):
                h[j], c[j] = lstm_cell(lstm_input, (h[j], c[j]))
                lstm_input = h[j]
            lstm_out.append(lstm_input)
        lstm_out = torch.stack(lstm_out, dim=1)
        # Take the output from the last time step
        out = lstm_out[:, -1, :]
        
        # Classifier
        out = self.classifier(out)
        
        return out
    def init_state(self, b_size, device):
        """ Initializing hidden and cell state """
        if(self.mode == "zeros"):
            h = [torch.zeros(b_size, self.hidden_dim).to(device) for _ in range(self.num_layers)]
            c = [torch.zeros(b_size, self.hidden_dim).to(device) for _ in range(self.num_layers)]
        elif(self.mode == "random"):
            h = [torch.zeros(b_size, self.hidden_dim).to(device) for _ in range(self.num_layers)]
            c = [torch.zeros(b_size, self.hidden_dim).to(device) for _ in range(self.num_layers)]
        return h, c
        # h = h.to(device)
        # c = c.to(device)
        # return h, c

In [16]:
model = ConvRecurrentClassifier(input_channels=3, hidden_size=128, num_classes=len(sequencesExtractor.get_classes()))

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
criterion = nn.CrossEntropyLoss()  # Note, that this already includes a Softmax!
optimizer = torch.optim.AdamW(model.parameters(), lr=LR) #adamW was used in the paper
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_dataloader), epochs=num_epochs)
model = model.to(device)

cuda


In [18]:
torch.cuda.empty_cache()

In [19]:
@torch.no_grad()
def eval_model(model):
    """ Computing model accuracy """
    correct = 0
    total = 0
    loss_list = []
    
    for sequences, labels in val_dataloader:
        sequences = sequences.to(device)
        labels = labels.to(device)
        
        # Forward pass only to get logits/output
        outputs = model(sequences)
                 
        loss = criterion(outputs, labels)
        loss_list.append(loss.item())
            
        # Get predictions from the maximum value
        preds = torch.argmax(outputs, dim=1)
        correct += len( torch.where(preds==labels)[0] )
        total += len(labels)
                 
    # Total correct predictions and loss
    accuracy = correct / total * 100
    loss = np.mean(loss_list)
    return accuracy, loss

In [20]:
#SAMPLE Training
loss_hist = []
train_acc_hist = []
valid_acc_hist = []
valid_loss_hist = []
best_loss = 100
for epoch in range(num_epochs):
    loss_list = []
    acc_list = []
    progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for i, (sequences, labels) in progress_bar:
        sequences = sequences.to(device)
        labels = labels.to(device)
        
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
        
        # Forward pass to get output/logits
        outputs = model(sequences)
         
        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)
        loss_list.append(loss.item())
        
        with torch.no_grad():
            predicted = outputs.argmax(dim=-1)
            correct = (predicted == labels).sum().item()
            accuracy = correct/labels.shape[0] * 100
        acc_list.append(accuracy)
        # Getting gradients w.r.t. parameters
        loss.backward()
       # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2.0)
        # Updating parameters
        optimizer.step()
       # scheduler.step()
        
        progress_bar.set_description(f"Epoch {epoch+1} Iter {i+1}: loss {loss.item():.5f}. ")
    
    loss_hist.append(np.mean(loss_list))
    train_acc_hist.append(np.mean(acc_list))
    val_accuracy, valid_loss = eval_model(model)
    if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f"models/checkpoint_epoch_{epoch}.pth")
    print(f"Val accuracy at epoch {epoch}: {round(val_accuracy, 2)}%")
    valid_loss_hist.append(valid_loss)
    valid_acc_hist.append(val_accuracy)

Epoch 1 Iter 308: loss 1.80868. : 100%|███████████████████████████████████████████████| 308/308 [00:30<00:00,  9.97it/s]


Val accuracy at epoch 0: 23.18%


Epoch 2 Iter 308: loss 1.71787. : 100%|███████████████████████████████████████████████| 308/308 [00:31<00:00,  9.87it/s]


Val accuracy at epoch 1: 23.18%


Epoch 3 Iter 308: loss 1.72162. : 100%|███████████████████████████████████████████████| 308/308 [00:31<00:00,  9.81it/s]


Val accuracy at epoch 2: 23.18%


Epoch 4 Iter 308: loss 1.74556. : 100%|███████████████████████████████████████████████| 308/308 [00:31<00:00,  9.75it/s]


Val accuracy at epoch 3: 23.18%


Epoch 5 Iter 308: loss 1.74241. : 100%|███████████████████████████████████████████████| 308/308 [00:31<00:00,  9.65it/s]


Val accuracy at epoch 4: 23.18%


Epoch 6 Iter 308: loss 1.81384. : 100%|███████████████████████████████████████████████| 308/308 [00:32<00:00,  9.58it/s]


Val accuracy at epoch 5: 23.18%


Epoch 7 Iter 308: loss 1.71015. : 100%|███████████████████████████████████████████████| 308/308 [00:32<00:00,  9.55it/s]


Val accuracy at epoch 6: 23.18%


Epoch 8 Iter 308: loss 1.77283. : 100%|███████████████████████████████████████████████| 308/308 [00:32<00:00,  9.53it/s]


Val accuracy at epoch 7: 23.18%


Epoch 9 Iter 308: loss 1.76986. : 100%|███████████████████████████████████████████████| 308/308 [00:32<00:00,  9.51it/s]


Val accuracy at epoch 8: 23.18%


Epoch 10 Iter 308: loss 1.85641. : 100%|██████████████████████████████████████████████| 308/308 [00:32<00:00,  9.51it/s]


Val accuracy at epoch 9: 23.18%


Epoch 11 Iter 308: loss 1.71220. : 100%|██████████████████████████████████████████████| 308/308 [00:32<00:00,  9.50it/s]


Val accuracy at epoch 10: 23.18%


Epoch 12 Iter 308: loss 1.80133. : 100%|██████████████████████████████████████████████| 308/308 [00:32<00:00,  9.48it/s]


Val accuracy at epoch 11: 23.18%


Epoch 13 Iter 308: loss 1.78872. : 100%|██████████████████████████████████████████████| 308/308 [00:32<00:00,  9.49it/s]


Val accuracy at epoch 12: 23.18%


Epoch 14 Iter 285: loss 1.79094. :  93%|██████████████████████████████████████████▌   | 285/308 [00:30<00:02,  9.46it/s]


KeyboardInterrupt: 