In [2]:
import copy
import os
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.utils.scheduler import cosine_schedule
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from dataset import create_wall_dataloader

# Configs
# Hardware: use the GPU when one is visible to torch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data / model dimensions
INPUT_SHAPE = (2, 65, 65)   # observations: 2 channels of 65x65 pixels
ACTION_DIM = 2              # length of each action vector
STATE_DIM = 256             # size of the encoder's latent state
HIDDEN_DIM = 600            # hidden size of the GRU predictor

# Optimisation
BATCH_SIZE = 32
MOMENTUM = 0.996            # EMA momentum for the BYOL targets — NOTE(review): not referenced by the visible cells

In [4]:
# Training split location. Set the DL24FA_TRAIN_DIR environment variable to
# point at a different copy instead of editing this user-specific absolute path.
TRAIN_DATA_DIR = os.environ.get("DL24FA_TRAIN_DIR", "/drive_reader/as16386/DL24FA/train")

train_loader = create_wall_dataloader(
    TRAIN_DATA_DIR,
    batch_size=BATCH_SIZE,
    train=True,
    num_samples=None,  # presumably "use every sample" — confirm in dataset.py
)

In [7]:
# Pull a single batch to check the structure the dataloader yields.
train_iter = iter(train_loader)
batch = next(train_iter)

print("Batch type:", type(batch))
print("\nBatch is a named tuple with fields:", batch._fields)

# Same shape/dtype/device summary for each tensor field used in training.
for field_label, field_tensor in (("States", batch.states), ("Actions", batch.actions)):
    print(f"\n{field_label} tensor:")
    print("- Shape:", field_tensor.shape)
    print("- Type:", field_tensor.dtype)
    print("- Device:", field_tensor.device)

Batch type: <class 'dataset.WallSample'>

Batch is a named tuple with fields: ('states', 'locations', 'actions')

States tensor:
- Shape: torch.Size([32, 17, 2, 65, 65])
- Type: torch.float32
- Device: cuda:0

Actions tensor:
- Shape: torch.Size([32, 16, 2])
- Type: torch.float32
- Device: cuda:0


In [8]:
class StateEncoder(nn.Module):
    """CNN mapping a 2-channel 65x65 observation to a ``state_dim`` vector.

    Four Conv(k=4, s=2, p=1) -> BatchNorm -> ReLU blocks shrink the spatial
    size 65 -> 32 -> 16 -> 8 -> 4 while widening the channels
    2 -> 32 -> 64 -> 128 -> 256. The final 4x4x256 map is flattened and
    projected linearly to ``state_dim``.
    """

    def __init__(self, state_dim: int = STATE_DIM):
        super().__init__()

        widths = [2, 32, 64, 128, 256]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # The 4x4 spatial size relies on 65x65 inputs (see the shape trace above).
        layers += [nn.Flatten(), nn.Linear(4 * 4 * 256, state_dim)]

        self.encoder = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

# Test encoder: a random batch shaped like real observations must come out as
# one STATE_DIM vector per sample.
encoder = StateEncoder().to(DEVICE)
test_input = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(DEVICE)
test_output = encoder(test_input)
assert tuple(test_output.shape) == (BATCH_SIZE, STATE_DIM), (
    f"Expected shape {(BATCH_SIZE, STATE_DIM)}, got {test_output.shape}"
)
print("Encoder parameter count:", sum(param.numel() for param in encoder.parameters()))

Encoder parameter count: 1739424


In [9]:
class GRUPredictor(nn.Module):
    """One-step latent dynamics model built around a single-layer GRU.

    The encoded state, projected to ``hidden_dim``, is used as the GRU's
    initial hidden state. The projected action is fed as a single input time
    step. The resulting hidden state is mapped back to ``state_dim`` as the
    predicted next latent state.
    """

    def __init__(self, state_dim: int = STATE_DIM, action_dim: int = ACTION_DIM, hidden_dim: int = HIDDEN_DIM):
        super().__init__()

        # Latent state -> initial hidden state h_0.
        self.state_proj = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ELU(),
            nn.LayerNorm(hidden_dim),
        )
        # Action -> GRU input for the single time step.
        self.action_proj = nn.Linear(action_dim, hidden_dim)
        self.gru = nn.GRU(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)
        # Final hidden state -> predicted latent state.
        self.out_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, state_dim),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # h_0 is (num_layers=1, B, H); the input is batch-first (B, T=1, H).
        h0 = self.state_proj(state).unsqueeze(0)
        step_input = self.action_proj(action).unsqueeze(1)

        _, h_n = self.gru(step_input, h0)
        return self.out_proj(h_n[0])

# BYOL projection and prediction heads
class ProjectionHead(nn.Module):
    """BYOL projector MLP: ``input_dim`` -> 1024 (BatchNorm + ReLU) -> 256."""

    def __init__(self, input_dim: int = STATE_DIM):
        super().__init__()
        hidden_width, out_width = 1024, 256
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_width),
            nn.BatchNorm1d(hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, out_width),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class PredictionHead(nn.Module):
    """BYOL predictor MLP: ``input_dim`` -> 1024 (BatchNorm + ReLU) -> 256.

    Stacked on the online projection so its output can be compared with the
    target network's projection.
    """

    def __init__(self, input_dim: int = 256):
        super().__init__()
        hidden_width, out_width = 1024, 256
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_width),
            nn.BatchNorm1d(hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, out_width),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# Test GRU and heads: chain predictor -> projector -> predictor on random inputs
# and check the shape after every stage.
gru = GRUPredictor().to(DEVICE)
proj = ProjectionHead().to(DEVICE)
pred = PredictionHead().to(DEVICE)

test_state = torch.randn(BATCH_SIZE, STATE_DIM).to(DEVICE)
test_action = torch.randn(BATCH_SIZE, ACTION_DIM).to(DEVICE)

test_gru_out = gru(test_state, test_action)
test_proj_out = proj(test_gru_out)
test_pred_out = pred(test_proj_out)

assert tuple(test_gru_out.shape) == (BATCH_SIZE, STATE_DIM)
assert tuple(test_proj_out.shape) == (BATCH_SIZE, 256)
assert tuple(test_pred_out.shape) == (BATCH_SIZE, 256)

In [10]:
class BYOLGRU(nn.Module):
    """BYOL-style latent dynamics model.

    Online branch:  state -> encoder -> gru(·, action) -> projection -> prediction -> p1
    Target branch:  next_state -> encoder_momentum -> projection_momentum -> z3

    The target networks are deep copies of the online encoder/projection head
    with gradients disabled. They are only changed by ``update_target`` and by
    their own forward passes (e.g. BatchNorm running statistics).
    """

    def __init__(self):
        super().__init__()

        # Online networks (trained by backprop).
        self.encoder = StateEncoder()
        self.gru = GRUPredictor()
        self.projection = ProjectionHead()
        self.prediction = PredictionHead()

        # Target networks: frozen copies of the online encoder/projector.
        self.encoder_momentum = copy.deepcopy(self.encoder)
        self.projection_momentum = copy.deepcopy(self.projection)
        deactivate_requires_grad(self.encoder_momentum)
        deactivate_requires_grad(self.projection_momentum)

    def forward(self,
                state: torch.Tensor,
                action: torch.Tensor,
                next_state: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both BYOL branches.

        Args:
            state: Observation at time t (batched like ``[B, *INPUT_SHAPE]``).
            action: Action taken at time t, ``[B, action_dim]``.
            next_state: Observation at time t+1 that the target branch encodes.
                If omitted, ``state`` is used instead. This keeps old two-argument
                calls working, but the target is then the current frame rather
                than the next one.

        Returns:
            ``(p1, z3)``: the online prediction and the (gradient-free) target
            projection, each ``[B, 256]``.
        """
        if next_state is None:
            next_state = state

        # Online forward
        z1 = self.encoder(state)
        z2 = self.gru(z1, action)
        p1 = self.prediction(self.projection(z2))

        # Target forward on the NEXT observation; no gradients flow here.
        with torch.no_grad():
            z3 = self.projection_momentum(self.encoder_momentum(next_state))

        return p1, z3

    @torch.no_grad()
    def update_target(self, m: float):
        """EMA update of the target networks: ``target <- m * target + (1 - m) * online``.

        Only parameters are averaged; BatchNorm buffers are not parameters and
        are left untouched here.
        """
        m = float(m)
        pairs = ((self.encoder, self.encoder_momentum),
                 (self.projection, self.projection_momentum))
        for online_net, target_net in pairs:
            for online, target in zip(online_net.parameters(), target_net.parameters()):
                # In place: keeps the same tensors instead of rebinding `.data`
                # to a freshly allocated tensor on every step.
                target.mul_(m).add_(online, alpha=1.0 - m)

# Test full model: one forward pass through both BYOL branches; the online
# prediction and the target projection must share the same 256-d shape.
model = BYOLGRU().to(DEVICE)
test_state = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(DEVICE)
test_action = torch.randn(BATCH_SIZE, ACTION_DIM).to(DEVICE)

p1, z3 = model(test_state, test_action)
assert tuple(p1.shape) == (BATCH_SIZE, 256)
assert tuple(z3.shape) == (BATCH_SIZE, 256)

In [15]:
def train_epoch(model: BYOLGRU,
                dataloader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                total_epochs: int,
                base_momentum: float = 0.996) -> float:
    """Train ``model`` for one epoch on next-state prediction with a BYOL loss.

    Each transition (s_t, a_t, s_{t+1}) in a sequence is handled as follows:
    - The online branch is fed (s_t, a_t).
    - Its prediction is pulled towards the momentum target's projection of s_{t+1}.
    - The per-step loss is BYOL's normalized MSE ``2 - 2 * cos(p, z)``.
    - Per-step losses are averaged over the sequence.

    One optimizer step is taken per batch, followed by an EMA update of the
    target networks.

    Args:
        model: BYOL-GRU model; its ``encoder_momentum``/``projection_momentum``
            submodules provide the target branch.
        dataloader: Yields ``(states, locations, actions)`` batches with states
            ``[B, T, C, H, W]`` and actions ``[B, T-1, action_dim]``.
        optimizer: Optimizer over the model's trainable parameters.
        epoch: Current epoch index (drives the momentum schedule).
        total_epochs: Total number of epochs (drives the momentum schedule).
        base_momentum: EMA momentum at epoch 0, cosine-annealed towards 1.0.

    Returns:
        Mean per-batch loss over the epoch (0.0 if no batch was usable).
    """
    model.train()
    device = next(model.parameters()).device

    # The momentum depends only on the epoch, so compute it once, not per batch.
    m = cosine_schedule(epoch, total_epochs, base_momentum, 1)

    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(dataloader):
        states, _, actions = batch
        states = states.to(device)
        actions = actions.to(device)

        num_transitions = states.shape[1] - 1
        if num_transitions < 1:
            # A single-frame sequence has no next state to predict.
            continue

        step_losses = []
        for t in range(num_transitions):
            current_state = states[:, t]       # [B, C, H, W]
            next_state = states[:, t + 1]      # [B, C, H, W]
            current_action = actions[:, t]     # [B, action_dim]

            # Online prediction from (s_t, a_t). The target returned by
            # `model(current_state, current_action)` encodes `current_state`
            # itself, so it is discarded (at the cost of one extra target pass).
            p1, _ = model(current_state, current_action)

            # Target: momentum encoder + projector applied to the NEXT frame.
            with torch.no_grad():
                z3 = model.projection_momentum(model.encoder_momentum(next_state))

            p1 = F.normalize(p1, dim=-1)
            z3 = F.normalize(z3, dim=-1)
            step_losses.append(2 - 2 * (p1 * z3).sum(dim=-1).mean())

        # Average loss over the sequence's transitions.
        sequence_loss = torch.stack(step_losses).mean()

        optimizer.zero_grad()
        sequence_loss.backward()
        optimizer.step()

        # Update momentum networks
        model.update_target(m)

        total_loss += sequence_loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

# Training setup
CHECKPOINT_DIR = "byol_try_checkpoints"
CHECKPOINT_EVERY = 10   # epochs between periodic checkpoints
num_epochs = 100

# torch.save does not create missing directories. Without this, the first
# checkpoint (written after a full epoch of training) raises FileNotFoundError.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

model = BYOLGRU().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
print("Starting training")
for epoch in range(num_epochs):
    avg_loss = train_epoch(model, train_loader, optimizer, epoch, num_epochs)
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")

    # Save periodically, and always after the final epoch. The last epoch need
    # not be a multiple of CHECKPOINT_EVERY, so it would otherwise be lost.
    if epoch % CHECKPOINT_EVERY == 0 or epoch == num_epochs - 1:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, os.path.join(CHECKPOINT_DIR, f'byol_gru_checkpoint_{epoch}.pt'))

Starting training


  9%|▉         | 418/4594 [01:50<18:24,  3.78it/s]


KeyboardInterrupt: 