In [1]:
import copy
import os
import random
import time
from glob import glob

import numpy as np
import torch
from PIL import Image
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.models.video import r3d_18
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm

from utils import plot_sequences


In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus all CUDA devices when CUDA is available). When CUDA is
    available, cuDNN is also switched to deterministic mode with autotuning
    disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # `np` comes from the notebook's import cell (`import numpy as np`).
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed(0)

In [None]:


class Custom3DDataset(Dataset):
    """Fixed-length frame sequences for video activity classification.

    ``root_dir`` is expected to hold one entry per activity; every ``*.png``
    file inside an activity directory is a frame. Frames are grouped by the
    first two underscore-separated fields of their file name (treated as a
    subject/session key) and each group contributes one or more sequences of
    ``sequence_length`` frame paths.

    Args:
        root_dir: Directory whose entries are the activity classes.
        sequence_length: Number of frames per returned sequence (>= 1).
        sampling: How sequences are drawn from a group that has at least
            ``sequence_length`` frames:
            ``"multiple-consecutive"`` - every sliding window of consecutive
            frames; ``"multiple-random"`` - the same number of windows, each
            with a random start; ``"single-random"`` - one sequence of randomly
            chosen frames kept in sorted order.
            Shorter groups are always padded by repeating their last frame.
        transform: Optional callable applied to each PIL image. It must return
            tensors of equal shape, because the frames are stacked with
            ``torch.stack``.

    Raises:
        ValueError: If ``sampling`` is not one of ``SAMPLING_MODES`` or
            ``sequence_length`` is smaller than 1.
    """

    SAMPLING_MODES = ("multiple-consecutive", "multiple-random", "single-random")

    def __init__(self, root_dir, sequence_length=16, sampling="single-random", transform=None):
        # Fail fast: an unknown mode used to silently skip every group that did
        # not need padding, yielding a mostly empty dataset.
        if sampling not in self.SAMPLING_MODES:
            raise ValueError(
                f"Unknown sampling mode {sampling!r}; expected one of {self.SAMPLING_MODES}"
            )
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

        self.root_dir = root_dir
        self.sequence_length = sequence_length
        self.transform = transform
        self.sampling = sampling
        # Listed and sorted once: labels stay stable for the dataset's lifetime
        # and __getitem__ no longer hits the filesystem to resolve a label.
        self.class_names = sorted(os.listdir(root_dir))
        self.sequences = self._create_sequences()

    def _create_sequences(self):
        """Return a list of ``(frame_paths, activity)`` tuples for the whole dataset."""
        sequences = []
        # Iterate in sorted order so that seeded random sampling does not depend
        # on the filesystem's directory listing order.
        for activity in self.class_names:
            activity_dir = os.path.join(self.root_dir, activity)
            frame_paths = sorted(glob(os.path.join(activity_dir, '*.png')))
            grouped_frames = self._group_frames_by_subject_and_session(frame_paths)
            for group in grouped_frames.values():
                for sequence in self._sample_sequences(group):
                    sequences.append((sequence, activity))
        return sequences

    def _sample_sequences(self, frames):
        """Return the list of sequences drawn from one group of frame paths."""
        length = self.sequence_length
        if len(frames) < length:
            return [self._pad_sequence(frames)]

        last_start = len(frames) - length
        if self.sampling == "multiple-consecutive":
            return [frames[start:start + length] for start in range(last_start + 1)]
        if self.sampling == "multiple-random":
            # Same count as the consecutive mode; starts are drawn with replacement.
            starts = [random.randint(0, last_start) for _ in range(last_start + 1)]
            return [frames[start:start + length] for start in starts]
        if self.sampling == "single-random":
            return [sorted(random.sample(frames, length))]
        raise ValueError(
            f"Unknown sampling mode {self.sampling!r}; expected one of {self.SAMPLING_MODES}"
        )

    def _pad_sequence(self, frames):
        """Return a copy of ``frames`` padded to ``sequence_length`` with its last frame.

        The input list is left untouched. An empty input yields an empty list
        because there is no frame to repeat.
        """
        if len(frames) == 0:
            return []
        missing = max(self.sequence_length - len(frames), 0)
        return list(frames) + [frames[-1]] * missing

    def _group_frames_by_subject_and_session(self, frames):
        """Group frame paths by the first two ``_``-separated fields of the file name."""
        grouped_frames = {}
        for frame in frames:
            filename = os.path.basename(frame)
            subject_session = '_'.join(filename.split('_')[:2])
            grouped_frames.setdefault(subject_session, []).append(frame)
        return grouped_frames

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Load one sequence and return ``(frames, label)``.

        ``frames`` is the stack of transformed images along a new leading
        axis, i.e. (T, C, H, W) when the transform yields (C, H, W) tensors.
        """
        sequence, activity = self.sequences[idx]
        frames = []
        for frame_path in sequence:
            # The context manager closes the file as soon as the RGB copy exists.
            with Image.open(frame_path) as raw_image:
                image = raw_image.convert('RGB')  # Convert to RGB
            if self.transform:
                image = self.transform(image)
            frames.append(image)
        frames = torch.stack(frames)  # (T, C, H, W)
        label = self._get_label(activity)
        return frames, label

    def _get_label(self, activity):
        """Return the index of ``activity`` in the sorted entries of ``root_dir``.

        Raises:
            ValueError: If ``activity`` is not one of the known class names.
        """
        return self.class_names.index(activity)


# Preprocessing: the training and test pipelines apply the same deterministic
# steps (resize -> center crop -> tensor -> per-channel normalization).
# NOTE(review): the mean/std values look like the statistics of torchvision's
# pretrained video models - confirm against the weights actually loaded.
def _make_frame_transform():
    """Build a fresh per-frame preprocessing pipeline."""
    steps = [
        transforms.Resize((128, 171), interpolation=InterpolationMode.BILINEAR),
        transforms.CenterCrop((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize([0.43216, 0.394666, 0.37645], [0.22803, 0.22145, 0.216989]),
    ]
    return transforms.Compose(steps)


train_transforms = _make_frame_transform()
test_transforms = _make_frame_transform()

# Datasets use the class defaults for sequence length and sampling mode.
train_dataset = Custom3DDataset(
    root_dir='/mnt/data-tmp/ghazal/DARai_DATA/rgb_dataset/train',
    transform=train_transforms,
)
test_dataset = Custom3DDataset(
    root_dir='/mnt/data-tmp/ghazal/DARai_DATA/rgb_dataset/test',
    transform=test_transforms,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=4)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=4)

# Visual sanity check of the training sequences (helper from the local utils module).
plot_sequences(train_dataset, 2)

In [None]:
# Initialize the 3D ResNet model with torchvision's default pretrained weights.
# A string passed as `weights` is looked up by name in the weights enum, and
# that lookup is case-sensitive: the member is spelled "DEFAULT" ("Default"
# raises KeyError).
model = r3d_18(weights='DEFAULT')

# Replace the final fully connected layer so its output size matches the number
# of entries in the train directory (one entry per activity class).
num_classes = len(os.listdir('/mnt/data-tmp/ghazal/DARai_DATA/rgb_dataset/train'))  # Assuming class names are in the train directory
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move the model to the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)

# Loss, optimizer, and a scheduler that cuts the learning rate by 10x once the
# monitored quantity has not decreased for more than `patience` epochs.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)

# Initialize TensorBoard SummaryWriter
# writer = SummaryWriter('./runs/activity_classification')

def train_model(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one epoch and return ``(epoch_loss, epoch_acc)``.

    Args:
        model: Network to optimize; it is switched to training mode.
        train_loader: Iterable of ``(inputs, labels)`` batches exposing a
            ``dataset`` attribute with a length. RGB clips may be laid out
            either as (B, T, C, H, W) or as (B, C, T, H, W).
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple of the sample-weighted mean loss (float) and the accuracy
        (0-dim double tensor), both normalized by ``len(train_loader.dataset)``.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0
    start_time = time.time()

    for inputs, labels in tqdm(train_loader, desc="Training", leave=False):
        # Custom3DDataset stacks frames as (T, C, H, W), so batches arrive as
        # (B, T, C, H, W), while 3D-conv video models such as r3d_18 expect the
        # channel axis first: (B, C, T, H, W). Swap the axes when the size-3 RGB
        # axis sits at dim 2; batches that are already channel-first pass through.
        if inputs.dim() == 5 and inputs.size(1) != 3 and inputs.size(2) == 3:
            inputs = inputs.permute(0, 2, 1, 3, 4).contiguous()

        inputs = inputs.to(device)  # (B, C, T, H, W)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)  # Forward pass
        loss = criterion(outputs, labels)
        loss.backward()  # Backward pass
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch is
        # averaged correctly over the whole dataset.
        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels)

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)
    epoch_time = time.time() - start_time
    print(f"Training epoch took: {epoch_time:.2f}s")

    return epoch_loss, epoch_acc

def evaluate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and return ``(epoch_loss, epoch_acc)``.

    The model is switched to evaluation mode and gradients are disabled.

    Args:
        model: Network to evaluate.
        val_loader: Iterable of ``(inputs, labels)`` batches exposing a
            ``dataset`` attribute with a length. RGB clips may be laid out
            either as (B, T, C, H, W) or as (B, C, T, H, W).
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple of the sample-weighted mean loss (float) and the accuracy
        (0-dim double tensor), both normalized by ``len(val_loader.dataset)``.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    start_time = time.time()

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation", leave=False):
            # Custom3DDataset stacks frames as (T, C, H, W), so batches arrive
            # as (B, T, C, H, W), while 3D-conv video models such as r3d_18
            # expect (B, C, T, H, W). Swap the axes when the size-3 RGB axis
            # sits at dim 2; channel-first batches pass through untouched.
            if inputs.dim() == 5 and inputs.size(1) != 3 and inputs.size(2) == 3:
                inputs = inputs.permute(0, 2, 1, 3, 4).contiguous()

            inputs = inputs.to(device)  # (B, C, T, H, W)
            labels = labels.to(device)

            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)

            # Weight by batch size so a smaller final batch averages correctly.
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels)

    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = running_corrects.double() / len(val_loader.dataset)
    epoch_time = time.time() - start_time
    print(f"Validation epoch took: {epoch_time:.2f}s")

    return epoch_loss, epoch_acc

# Epoch loop: train, evaluate, remember the best weights, adapt the learning rate.
num_epochs = 25
# Start from a snapshot of the current weights so there is always something to restore.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print(f'Epoch {epoch}/{num_epochs - 1}')
    print('-' * 10)

    epoch_start_time = time.time()

    train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device)
    # NOTE(review): the "validation" metrics are computed on test_loader, so the
    # best-checkpoint selection and the LR schedule below are tuned on the test
    # split; the reported best accuracy is therefore not a held-out estimate.
    val_loss, val_acc = evaluate_model(model, test_loader, criterion, device)

    epoch_duration = time.time() - epoch_start_time
    print(f"Epoch {epoch} took: {epoch_duration:.2f}s")

    print(f'Training Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
    print(f'Validation Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

    # TensorBoard logging is currently disabled (the writer is commented out above).
    # writer.add_scalar('Loss/train', train_loss, epoch)
    # writer.add_scalar('Loss/val', val_loss, epoch)
    # writer.add_scalar('Accuracy/train', train_acc, epoch)
    # writer.add_scalar('Accuracy/val', val_acc, epoch)

    # Keep a deep copy of the weights whenever accuracy strictly improves;
    # deepcopy is needed because state_dict() returns references to live tensors.
    if val_acc > best_acc:
        best_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())

    # The scheduler is driven by the evaluation loss of this epoch.
    scheduler.step(val_loss)

print(f'Best val Acc: {best_acc:.4f}')

# Restore the best-scoring weights before saving them to disk.
model.load_state_dict(best_model_wts)
# writer.close()

torch.save(model.state_dict(), 'best_3d_resnet_model.pth')