In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision.models.video import r2plus1d_18
from torchvision.io import read_video
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import numpy as np
import itertools
import matplotlib.pyplot as plt
import os

In [None]:
class VideoDataset(Dataset):
    """Folder-per-class video dataset that yields fixed-length, fixed-size clips.

    ``root_dir`` must contain one sub-directory per class. Every non-hidden
    ``.mp4`` / ``.avi`` file inside a class directory becomes one sample, and
    class indices follow the alphabetical order of the directory names.

    Args:
        root_dir: Directory holding the per-class sub-directories.
        frame_count: Number of frames in every returned clip.
        resize: Target ``(height, width)`` handed to ``F.interpolate``.
        transform: Optional callable applied to the processed clip tensor.

    Attributes:
        data: List of video file paths.
        labels: List of integer class indices, aligned with ``data``.
        classes: Sorted list of class (directory) names.
        class_to_idx: Mapping from class name to its integer index.
    """

    def __init__(self, root_dir, frame_count, resize, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.frame_count = frame_count
        self.resize = resize
        self.data = []
        self.labels = []
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

        for idx, action_class in enumerate(self.classes):
            class_path = os.path.join(root_dir, action_class)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample index -> file mapping stable across runs and machines.
            for video_file in sorted(os.listdir(class_path)):
                if video_file.endswith(('.mp4', '.avi')) and not video_file.startswith('.'):
                    self.data.append(os.path.join(class_path, video_file))
                    self.labels.append(idx)

    def __len__(self):
        """Return the number of video files found under ``root_dir``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Decode, clip and resize the ``idx``-th video.

        Returns:
            ``(video, label)`` where ``video`` is the output of
            ``process_video`` (after ``transform``, if one was given) and
            ``label`` is the integer class index.

        Raises:
            RuntimeError: If no frame could be decoded from the file.
        """
        video_path = self.data[idx]
        label = self.labels[idx]
        # pts_unit='sec' is the unit torchvision recommends; the default
        # ('pts') emits a warning on every call.
        video, _, _ = read_video(video_path, pts_unit='sec')

        if video.shape[0] == 0:
            # Name the offending file instead of failing later with an
            # anonymous error inside process_video.
            raise RuntimeError(f"No frames could be decoded from '{video_path}'")

        video = self.process_video(video)

        if self.transform:
            video = self.transform(video)

        return video, label

    def process_video(self, video):
        """Turn a decoded video into a ``[channels, frame_count, H, W]`` float clip.

        Keeps the first ``frame_count`` frames, expands single-channel input to
        three channels, bilinearly resizes every frame to ``self.resize`` and,
        if the video is shorter than ``frame_count``, pads by repeating the
        last frame. Pixel values are only cast to float, not rescaled.

        Args:
            video: Tensor of shape ``[frames, height, width, channels]``.

        Returns:
            Float tensor of shape ``[channels, frame_count, *self.resize]``.

        Raises:
            ValueError: If ``video`` contains no frames.
        """
        if video.shape[0] == 0:
            raise ValueError("process_video received a video with no frames")

        # Slice before casting so unused trailing frames are never converted.
        # Layout becomes [frames, channels, height, width].
        clip = video[:self.frame_count].float().permute(0, 3, 1, 2)

        # Convert to RGB if grayscale
        if clip.size(1) == 1:
            clip = clip.repeat(1, 3, 1, 1)

        # Resize all frames in one call; bilinear interpolation treats each
        # batch element independently, so this matches a per-frame loop.
        clip = F.interpolate(clip, size=self.resize, mode='bilinear', align_corners=False)

        # Pad short videos by repeating the last frame
        missing = self.frame_count - clip.size(0)
        if missing > 0:
            clip = torch.cat([clip, clip[-1:].expand(missing, -1, -1, -1)], dim=0)

        # Permute to [channels, frames, height, width]
        return clip.permute(1, 0, 2, 3)


In [None]:
# Dataset / split configuration
root_dir = 'merged'
frame_count = 16
resize = (224, 224)
train_ratio = 0.8
batch_size = 2   # alternative used in experiments: 16
split_seed = 42  # fixed so the train/test partition is identical on every run

# Initialize dataset
dataset = VideoDataset(root_dir, frame_count, resize)
labels = dataset.labels

# Perform a stratified split so every class keeps the same train/test ratio.
# random_state is pinned on purpose: the model is warm-started from a saved
# checkpoint, so a different random split on each run would let clips that
# were trained on earlier leak into the test set and inflate the accuracy.
train_indices, test_indices = train_test_split(
    range(len(labels)),
    test_size=1 - train_ratio,
    stratify=labels,
    random_state=split_seed,
)

# Create Subset objects for train and test datasets
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

# Create data loaders for train and test sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Checkpoint file that is loaded here and written back after training.
PATH = 'r2plus1d_model_b2.pth'
# PATH = 'r2plus1d_model_b16.pth'

# Pick the device first so the checkpoint can be mapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# R(2+1)D-18 backbone with the classification head resized to this dataset.
model = r2plus1d_18()
num_classes = len(dataset.classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

if os.path.exists(PATH):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(PATH, map_location=device))
else:
    # On a fresh checkout the checkpoint does not exist yet; it is created by
    # the training cell, so start from randomly initialised weights instead
    # of failing here.
    print(f"Checkpoint '{PATH}' not found - starting from randomly initialised weights.")

model = model.to(device)
print(model)

In [None]:
# Training loop
num_epochs = 5

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    # Running sums over the epoch; the progress bar shows their per-sample averages.
    train_loss = 0.0
    train_correct = 0
    total_train = 0
    train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for videos, labels in train_bar:
        videos = videos.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(videos)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss is the batch mean, so weight it by batch size before summing.
        train_loss += loss.item() * videos.size(0)
        # argmax replaces the legacy `torch.max(outputs.data, 1)` idiom.
        predicted = outputs.argmax(dim=1)
        train_correct += (predicted == labels).sum().item()
        total_train += labels.size(0)

        train_bar.set_postfix(loss=train_loss/total_train, accuracy=100.0*train_correct/total_train)

# Write back to the checkpoint that was loaded (PATH) rather than a hard-coded
# filename, so switching PATH never overwrites a different checkpoint.
torch.save(model.state_dict(), PATH)

In [None]:
model.eval()

# Initialize the counters for top-1, top-3, and top-5 accuracy
top1_correct = 0
top3_correct = 0
top5_correct = 0
total = 0

with torch.no_grad():
    for videos, labels in test_loader:
        videos = videos.to(device)
        labels = labels.to(device)
        outputs = model(videos)

        # topk raises if k exceeds the number of classes, so clamp it.
        max_k = min(5, outputs.size(1))
        _, top5 = outputs.topk(max_k, dim=1, largest=True, sorted=True)

        # hits[b, r] is True when the r-th ranked prediction for sample b is
        # the true label; each row contains at most one True.
        hits = top5.eq(labels.view(-1, 1))

        # Increment the total counter
        total += labels.size(0)

        # Slicing past the available columns is safe when max_k < 3 or < 5.
        top1_correct += hits[:, :1].sum().item()
        top3_correct += hits[:, :3].sum().item()
        top5_correct += hits.sum().item()

# Calculate the top-1, top-3, and top-5 accuracies
if total == 0:
    # An empty test set would otherwise surface as a bare ZeroDivisionError.
    print("Test loader produced no samples; accuracies are undefined.")
else:
    acc1 = top1_correct / total
    acc3 = top3_correct / total
    acc5 = top5_correct / total
    print(f"Top-1 Accuracy: {acc1 * 100:.2f}%")
    print(f"Top-3 Accuracy: {acc3 * 100:.2f}%")
    print(f"Top-5 Accuracy: {acc5 * 100:.2f}%")

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools

# Set model to evaluation mode
model.eval()

n_classes = len(dataset.classes)

# True labels and top-1 predictions, collected for the confusion matrix
y_true = []
y_pred = []

# Per-class counters: topk_correct[k][c] counts samples of class c whose true
# label is among the k highest-scoring predictions; topk_total[c] counts all
# samples of class c.
topk_correct = {k: np.zeros(n_classes) for k in [1, 3, 5]}
topk_total = np.zeros(n_classes)

# Process each batch in test_loader
with torch.no_grad():
    for videos, labels in test_loader:
        videos = videos.to(device)
        outputs = model(videos)

        # topk raises if k exceeds the number of classes, so clamp it.
        max_k = min(5, outputs.size(1))
        _, top5 = outputs.topk(max_k, dim=1, largest=True, sorted=True)

        # Convert once per batch instead of calling .item() per element,
        # which would force a device synchronisation each time.
        batch_true = labels.tolist()
        batch_ranked = top5.tolist()

        for true_label, pred_labels in zip(batch_true, batch_ranked):
            y_true.append(true_label)
            y_pred.append(pred_labels[0])  # top-1 prediction for the confusion matrix
            topk_total[true_label] += 1

            # Check if the true label is in the top k predictions
            for k in [1, 3, 5]:
                if true_label in pred_labels[:k]:
                    topk_correct[k][true_label] += 1

# Calculate the confusion matrix; passing `labels` keeps a row/column for
# every class even if it never occurs in the test split.
conf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))

def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues):
    """Draw an annotated confusion-matrix heat map and show it.

    Args:
        cm: 2-D integer array; rows are labelled as true classes and columns
            as predicted classes.
        classes: Class names used as tick labels on both axes.
        title: Figure title.
        cmap: Matplotlib colormap for the heat map.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    fmt = 'd'
    # Cells darker than the mid-point get white text so the count stays readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

    # tight_layout only accounts for artists that exist when it is called, so
    # run it after the tick labels, cell texts and axis labels are in place.
    fig.tight_layout()
    plt.show()

plot_confusion_matrix(conf_matrix, dataset.classes)

# Per-class report: Top-1 / Top-3 / Top-5 accuracy, or a notice when the test
# split holds no sample of that class.
for class_idx, class_name in enumerate(dataset.classes):
    print(f"\nAccuracy for class {class_name}:")
    sample_count = topk_total[class_idx]
    for k in [1, 3, 5]:
        if not sample_count > 0:
            print(f"  Top-{k}: No samples available for this class.")
            continue
        accuracy = (topk_correct[k][class_idx] / sample_count) * 100
        print(f"  Top-{k}: {accuracy:.2f}%")