In [None]:
# Check that PyTorch can see a CUDA device.
# NOTE: torch.cuda.current_device() initialises the CUDA runtime; once that has
# happened, changing CUDA_VISIBLE_DEVICES later in this kernel has no effect.
import torch
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.current_device())
else:
    # current_device() raises when no CUDA device is present, so skip it.
    print("CUDA not available; PyTorch will use the CPU")

In [None]:
"""
EXTRACT ENCODER FROM PIX2SEQ CHECKPOINT
"""

import tensorflow as tf
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
import json

# ============================================================================
# PART 1: LOAD PIX2SEQ CHECKPOINT
# ============================================================================

MODEL_DIR = "/home/AD/sachith/pix2seq/data/HAR_pretrained/ts_model"

# Open a reader on the newest checkpoint prefix found in MODEL_DIR.
checkpoint_path = tf.train.latest_checkpoint(MODEL_DIR)
print(f"Loading checkpoint: {checkpoint_path}")
checkpoint = tf.train.load_checkpoint(checkpoint_path)

# Peek at the first 20 variable names (alphabetical) together with their shapes.
print("\n=== Checkpoint Variables ===")
var_shapes = checkpoint.get_variable_to_shape_map()
for var_name in sorted(var_shapes)[:20]:
    print(f"{var_name}: {var_shapes[var_name]}")
print(f"\n...  ({len(var_shapes)} total variables)")

# Loose name-based guess at the encoder variables, for inspection only
# ('image' in particular is a broad match).
encoder_vars = {
    k: v
    for k, v in var_shapes.items()
    if any(token in k.lower() for token in ('encoder', 'backbone', 'image'))
}
print(f"\n=== Encoder Variables ({len(encoder_vars)}) ===")
for var_name in sorted(encoder_vars)[:20]:
    print(f"{var_name}: {encoder_vars[var_name]}")

In [None]:
# ============================================================================
# PART 2: IDENTIFY ENCODER ARCHITECTURE
# ============================================================================

def identify_encoder_type(var_names):
    """Guess the encoder family from checkpoint variable names.

    Checks are ordered from most to least specific. Transformer attention
    markers ('/mha/', 'attention') are tested BEFORE the generic 'conv'
    substring. Otherwise a ViT whose patch embedding is a convolution
    (e.g. a 'stem_conv' variable alongside 'enc_layers/N/mha/...') would be
    reported as 'ConvNet'.

    Args:
        var_names: iterable of variable-name strings.

    Returns:
        One of 'ViT', 'ResNet', 'EfficientNet', 'ConvNet', 'Unknown'.
    """
    var_names_str = ' '.join(var_names).lower()

    if 'vit' in var_names_str or 'vision_transformer' in var_names_str:
        return 'ViT'
    if 'resnet' in var_names_str:
        return 'ResNet'
    if 'efficientnet' in var_names_str:
        return 'EfficientNet'
    if '/mha/' in var_names_str or 'attention' in var_names_str:
        # Transformer encoder without an explicit 'vit' tag in its variable paths.
        return 'ViT'
    if 'conv' in var_names_str:  # also covers 'convnet'
        return 'ConvNet'
    return 'Unknown'

# Iterating a dict yields its keys, i.e. the checkpoint variable names.
encoder_type = identify_encoder_type(list(var_shapes))
print(f"\n=== Detected Encoder Type:  {encoder_type} ===")

In [None]:
# ============================================================================
# PART 3: EXTRACT ENCODER WEIGHTS
# ============================================================================

def extract_encoder_weights(checkpoint, encoder_type='ViT'):
    """Extract encoder weights from a TensorFlow checkpoint reader.

    A variable is kept when its lower-cased name:
      * contains one of the encoder patterns,
      * contains none of the decoder patterns, and
      * is not optimizer state (see below).

    Args:
        checkpoint: reader exposing ``get_variable_to_shape_map()`` and
            ``get_tensor(name)``, e.g. the object returned by
            ``tf.train.load_checkpoint``.
        encoder_type: currently unused; kept so existing calls keep working.

    Returns:
        dict mapping checkpoint variable name -> value returned by
        ``checkpoint.get_tensor``. Variables that fail to load are reported
        and skipped (best effort).
    """
    # Filter encoder variables (adjust patterns based on your model)
    encoder_patterns = (
        'encoder',
        'backbone',
        'image_encoder',
        'visual_encoder',
        'feature_extractor',
        # Add patterns specific to your architecture
    )
    decoder_patterns = ('decoder', 'output_head', 'token_embed')
    # TF2 object-based checkpoints store optimizer state (e.g. Adam m/v) under
    # '<weight path>/.OPTIMIZER_SLOT/...'. Those names contain 'encoder' too,
    # but they are not model weights, so drop them.
    optimizer_slot_marker = '.optimizer_slot'

    weights = {}
    for var_name in checkpoint.get_variable_to_shape_map().keys():
        lowered = var_name.lower()
        if not any(pattern in lowered for pattern in encoder_patterns):
            continue
        if any(pattern in lowered for pattern in decoder_patterns):
            continue
        if optimizer_slot_marker in lowered:
            continue
        try:
            tensor = checkpoint.get_tensor(var_name)
            weights[var_name] = tensor
            print(f"Extracted:  {var_name} {tensor.shape}")
        except Exception as e:  # best effort: report and move on
            print(f"Failed to extract {var_name}: {e}")

    return weights

# encoder_type is forwarded for API compatibility; the extractor filters by name only.
encoder_weights = extract_encoder_weights(checkpoint, encoder_type=encoder_type)
print(f"\n=== Extracted {len(encoder_weights)} encoder weight tensors ===")

In [None]:
# ============================================================================
# PART 4: SAVE ENCODER WEIGHTS
# ============================================================================

SAVE_DIR = "/home/AD/sachith/pix2seq/data/HAR_pretrained/extracted_encoder"
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

# One .npy entry per checkpoint variable name, bundled into a single .npz archive.
np.savez(f"{SAVE_DIR}/encoder_weights.npz", **encoder_weights)

# Side-car JSON recording what was saved and where it came from.
metadata = dict(
    encoder_type=encoder_type,
    checkpoint_path=checkpoint_path,
    num_weights=len(encoder_weights),
    variable_names=[*encoder_weights],
    shapes={name: list(arr.shape) for name, arr in encoder_weights.items()},
)
Path(f"{SAVE_DIR}/encoder_metadata.json").write_text(json.dumps(metadata, indent=2))

print(f"\n✓ Saved encoder weights to {SAVE_DIR}")

In [None]:
# %%
# ============================================================================
# COMPLETE CONVERSION SCRIPT - CLEAN VERSION
# ============================================================================

import tensorflow as tf
import numpy as np
import torch
from pathlib import Path
import json

# Paths
MODEL_DIR = "/home/AD/sachith/pix2seq/data/HAR_pretrained/ts_model"
SAVE_DIR = "/home/AD/sachith/pix2seq/data/HAR_pretrained/extracted_encoder"
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

print("="*80)
print("PIX2SEQ ENCODER EXTRACTION & CONVERSION")
print("="*80)

# Load TensorFlow checkpoint
checkpoint_path = tf.train.latest_checkpoint(MODEL_DIR)
print(f"\n1. Loading TensorFlow checkpoint: {checkpoint_path}")
checkpoint = tf.train.load_checkpoint(checkpoint_path)

# Extract encoder variables. TF2 object-based checkpoints also store optimizer
# state (e.g. Adam moments) under '<weight path>/.OPTIMIZER_SLOT/...'. Those
# names contain 'encoder' as well, and they match the same conversion patterns
# as the weight itself (e.g. 'stem_conv/kernel'), so they must be excluded.
var_shapes = checkpoint.get_variable_to_shape_map()
encoder_vars = {
    k: v for k, v in var_shapes.items()
    if 'encoder' in k.lower()
    and 'decoder' not in k.lower()
    and '.OPTIMIZER_SLOT' not in k
}

print(f"   Found {len(encoder_vars)} encoder variables")

# Extract weights (name -> numpy array)
encoder_weights = {var_name: checkpoint.get_tensor(var_name) for var_name in encoder_vars}

print(f"   Extracted {len(encoder_weights)} weight tensors")

# %%
# ============================================================================
# CONVERT TO PYTORCH FORMAT
# ============================================================================

def convert_to_pytorch(tf_weights, hidden_dim=768, num_heads=12):
    """Convert TensorFlow Pix2Seq encoder weights to PyTorch format.

    TF variable names are matched by substring and renamed to the PyTorch
    parameter names used below. Kernels are transposed from the Keras
    ``[in, out]`` (conv: ``[H, W, C_in, C_out]``) layout to the PyTorch
    ``[out, in]`` (conv: ``[C_out, C_in, H, W]``) layout.

    Variables whose name contains '.OPTIMIZER_SLOT' are skipped. These are
    optimizer state such as Adam moments, which TF2 stores under the path of
    the weight they belong to. They match the same patterns as the real weight
    and would otherwise silently overwrite it.

    Args:
        tf_weights: dict mapping TF variable name -> numpy array.
        hidden_dim: model width, used to fold the per-head axes of the
            attention kernels into 2-D matrices.
        num_heads: unused (the head count is implied by the kernel shapes);
            kept for API compatibility.

    Returns:
        (pt_weights, converted_count): dict of PyTorch name -> contiguous
        ``torch.Tensor``, and the number of entries in that dict.

    Raises:
        ValueError: if two TF variables map to the same PyTorch name.
    """
    import re

    layer_re = re.compile(r'enc_layers/(\d+)/')

    def qkv_kernel(t):
        # [hidden, num_heads, head_dim] (or already [hidden, hidden]) -> [out, in]
        return t.reshape(hidden_dim, -1).T if len(t.shape) == 3 else t.T

    def qkv_bias(t):
        # [num_heads, head_dim] -> [hidden]
        return t.flatten() if len(t.shape) > 1 else t

    def out_kernel(t):
        # [num_heads, head_dim, hidden] (or [hidden, hidden]) -> [out, in]
        return t.reshape(-1, hidden_dim).T if len(t.shape) == 3 else t.T

    pt_weights = {}
    source_name = {}  # pt_name -> tf_name, for collision diagnostics

    for tf_name, tf_tensor in tf_weights.items():
        if '.OPTIMIZER_SLOT' in tf_name:
            continue

        pt_tensor = None
        pt_name = None

        # STEM CONV: [H, W, C_in, C_out] -> [C_out, C_in, H, W]
        if 'stem_conv/kernel' in tf_name:
            pt_tensor = np.transpose(tf_tensor, (3, 2, 0, 1))
            pt_name = 'stem_conv.weight'
        elif 'stem_conv/bias' in tf_name:
            pt_tensor = tf_tensor
            pt_name = 'stem_conv.bias'

        # STEM LAYER NORM
        elif 'stem_ln/gamma' in tf_name:
            pt_tensor = tf_tensor
            pt_name = 'stem_ln.weight'
        elif 'stem_ln/beta' in tf_name:
            pt_tensor = tf_tensor
            pt_name = 'stem_ln.bias'

        # OUTPUT LAYER NORM
        elif 'output_ln/gamma' in tf_name:
            pt_tensor = tf_tensor
            pt_name = 'output_ln.weight'
        elif 'output_ln/beta' in tf_name:
            pt_tensor = tf_tensor
            pt_name = 'output_ln.bias'

        # ENCODER LAYERS
        elif 'enc_layers/' in tf_name:
            layer_match = layer_re.search(tf_name)
            if not layer_match:
                continue
            prefix = f'encoder_layers.{int(layer_match.group(1))}'

            # ATTENTION LAYERS
            if '/mha/' in tf_name:
                if '_query_dense/kernel' in tf_name:
                    pt_tensor, pt_name = qkv_kernel(tf_tensor), f'{prefix}.mha.query_dense.weight'
                elif '_query_dense/bias' in tf_name:
                    pt_tensor, pt_name = qkv_bias(tf_tensor), f'{prefix}.mha.query_dense.bias'
                elif '_key_dense/kernel' in tf_name:
                    pt_tensor, pt_name = qkv_kernel(tf_tensor), f'{prefix}.mha.key_dense.weight'
                elif '_key_dense/bias' in tf_name:
                    pt_tensor, pt_name = qkv_bias(tf_tensor), f'{prefix}.mha.key_dense.bias'
                elif '_value_dense/kernel' in tf_name:
                    pt_tensor, pt_name = qkv_kernel(tf_tensor), f'{prefix}.mha.value_dense.weight'
                elif '_value_dense/bias' in tf_name:
                    pt_tensor, pt_name = qkv_bias(tf_tensor), f'{prefix}.mha.value_dense.bias'
                elif '_output_dense/kernel' in tf_name:
                    pt_tensor, pt_name = out_kernel(tf_tensor), f'{prefix}.mha.output_dense.weight'
                elif '_output_dense/bias' in tf_name:
                    pt_tensor, pt_name = tf_tensor, f'{prefix}.mha.output_dense.bias'

            # ATTENTION LAYER NORM
            elif '/mha_ln/gamma' in tf_name:
                pt_tensor, pt_name = tf_tensor, f'{prefix}.mha_ln.weight'
            elif '/mha_ln/beta' in tf_name:
                pt_tensor, pt_name = tf_tensor, f'{prefix}.mha_ln.bias'

            # MLP LAYERS (Dense kernels are [in, out]; Linear.weight is [out, in])
            elif '/mlp/' in tf_name:
                if 'dense1/kernel' in tf_name:
                    pt_tensor, pt_name = tf_tensor.T, f'{prefix}.mlp.dense1.weight'
                elif 'dense1/bias' in tf_name:
                    pt_tensor, pt_name = tf_tensor, f'{prefix}.mlp.dense1.bias'
                elif 'dense2/kernel' in tf_name:
                    pt_tensor, pt_name = tf_tensor.T, f'{prefix}.mlp.dense2.weight'
                elif 'dense2/bias' in tf_name:
                    pt_tensor, pt_name = tf_tensor, f'{prefix}.mlp.dense2.bias'
                # MLP Layer Norm (nested inside mlp module)
                elif 'layernorms/0/gamma' in tf_name:
                    pt_tensor, pt_name = tf_tensor, f'{prefix}.mlp.layernorm.weight'
                elif 'layernorms/0/beta' in tf_name:
                    pt_tensor, pt_name = tf_tensor, f'{prefix}.mlp.layernorm.bias'

        if pt_tensor is None or pt_name is None:
            continue
        if pt_name in pt_weights:
            raise ValueError(
                f"TF variables {source_name[pt_name]!r} and {tf_name!r} both map to "
                f"{pt_name!r}; refusing to silently overwrite one with the other."
            )
        # Transposes yield strided views; store contiguous copies.
        pt_weights[pt_name] = torch.from_numpy(np.asarray(pt_tensor)).contiguous()
        source_name[pt_name] = tf_name

    return pt_weights, len(pt_weights)

print("\n2. Converting to PyTorch format...")
pt_weights, converted_count = convert_to_pytorch(encoder_weights)
print(f"   Converted {converted_count} parameters")

# Persist the renamed / re-laid-out tensors as a plain name -> tensor mapping.
torch_path = f"{SAVE_DIR}/encoder_weights.pth"
torch.save(pt_weights, torch_path)
print(f"\n3. Saved PyTorch weights: {torch_path}")

# The architecture fields are fixed values written as-is; they are not read from the checkpoint.
metadata = dict(
    source=checkpoint_path,
    architecture='ViT-Base',
    hidden_dim=768,
    num_layers=12,
    num_heads=12,
    num_parameters=converted_count,
    parameter_names=[*pt_weights],
)
Path(f"{SAVE_DIR}/metadata.json").write_text(json.dumps(metadata, indent=2))

print(f"   Saved metadata: {SAVE_DIR}/metadata.json")



In [None]:
from models.encoder_cls import EncoderClassifier

# ============================================================================
# USAGE
# ============================================================================

if __name__ == '__main__':
    # Create model on top of the converted encoder weights (encoder frozen).
    model = EncoderClassifier(
        num_classes=6,
        pretrained_encoder_path='/home/AD/sachith/pix2seq/data/HAR_pretrained/extracted_encoder/encoder_weights.pth',
        freeze_encoder=True,
        hidden_dims=[512],
        dropout=0.1,
        image_size=224,
    )

    # Test
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # Smoke test only: eval() disables dropout, and no_grad() avoids building
    # an autograd graph for this throwaway forward pass.
    model.eval()

    # Dummy input
    x = torch.randn(2, 3, 224, 224).to(device)
    with torch.no_grad():
        logits = model(x)

    print(f"\nTest output shape: {logits.shape}")  # [2, 6]
    print(f"✓ Model ready!")

In [None]:
# %%
# ============================================================================
# PART 7: TEST LOADING WITH PYTORCH WEIGHTS
# ============================================================================

from models.encoder_cls import EncoderClassifier

print("\n" + "="*80)
print("TESTING WEIGHT LOADING")
print("="*80)

model = EncoderClassifier(
    num_classes=6,
    pretrained_encoder_path=f'{SAVE_DIR}/encoder_weights.pth',
    freeze_encoder=True,
    hidden_dims=[512],
    dropout=0.1,
    image_size=224
)

print("\n✓ Model created successfully!")

# Test forward pass in eval mode, so dropout does not perturb the check.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

x = torch.randn(2, 3, 224, 224).to(device)
with torch.no_grad():
    logits = model(x)

print(f"\nTest output shape: {logits.shape}")  # [2, 6]
print(f"✓ Forward pass successful!")

# %%
# Verify which parameters are trainable
print("\n" + "="*80)
print("TRAINABLE PARAMETERS")
print("="*80)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())

print(f"Trainable: {trainable:,}")
print(f"Total: {total:,}")
print(f"Frozen: {total - trainable:,}")

print("\nTrainable layers:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"  {name}: {tuple(param.shape)}")

In [None]:
# %%
# ============================================================================
# VERIFY CONVERSION
# ============================================================================

print("\n" + "="*80)
print("VERIFICATION")
print("="*80)

from models.encoder_cls import EncoderClassifier

# Rebuild the classifier on the converted weights written to torch_path by the conversion cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EncoderClassifier(
    num_classes=6,
    pretrained_encoder_path=torch_path,
    freeze_encoder=True,
    hidden_dims=[512],
    dropout=0.1,
    image_size=224,
).to(device)
model.eval()

# A single random batch through the network, without autograd bookkeeping.
x = torch.randn(2, 3, 224, 224).to(device)
with torch.no_grad():
    logits = model(x)

print(f"\n✓ Forward pass successful!")
print(f"  Input shape:  {tuple(x.shape)}")
print(f"  Output shape: {tuple(logits.shape)}")

# "Frozen" is everything that does not require grad (labelled as the encoder below).
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = total - trainable

print(f"\n  Total parameters:     {total:,}")
print(f"  Frozen (encoder):     {frozen:,}")
print(f"  Trainable (head):     {trainable:,}")

print("\n" + "="*80)
print("✓ CONVERSION COMPLETE!")
print("="*80)
print(f"\nPretrained encoder ready at:")
print(f"  {torch_path}")
print(f"\nYou can now fine-tune on your HAR dataset!")

In [None]:
# %%
# ============================================================================
# QUICK TEST
# ============================================================================

print("\n" + "="*80)
print("QUICK TEST - VERIFYING EVERYTHING WORKS")
print("="*80)

from models.encoder_cls import EncoderClassifier
import torch

# Load model
model = EncoderClassifier(
    num_classes=6,
    pretrained_encoder_path='/home/AD/sachith/pix2seq/data/HAR_pretrained/extracted_encoder/encoder_weights.pth',
    freeze_encoder=True,
    hidden_dims=[512],
    dropout=0.1,
    image_size=224
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

# Test batch
batch_size = 8
x = torch.randn(batch_size, 3, 224, 224).to(device)

with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits, dim=1)

# The checks below are real assertions, so "ALL CHECKS PASSED" is only printed
# when they actually pass (previously they were only printed).
assert tuple(logits.shape) == (batch_size, 6), f"unexpected logits shape {tuple(logits.shape)}"
prob_sums = probs.sum(dim=1)
assert torch.allclose(prob_sums, torch.ones_like(prob_sums), atol=1e-4), \
    f"softmax rows do not sum to 1: {prob_sums}"

print(f"\n✓ Test passed!")
print(f"  Input:  {tuple(x.shape)}")
print(f"  Logits: {tuple(logits.shape)}")
print(f"  Probs:  {tuple(probs.shape)}")
print(f"  Prob sum per sample: {prob_sums}")  # Should be ~1.0

# Check gradients: encoder must be frozen and the head must have trainable params.
encoder_frozen = not any(p.requires_grad for p in model.encoder.parameters())
head_params = list(model.classifier.parameters())
head_trainable = bool(head_params) and all(p.requires_grad for p in head_params)
print(f"\n✓ Gradient flow check:")
print(f"  Encoder frozen: {encoder_frozen}")
print(f"  Head trainable: {head_trainable}")
assert encoder_frozen, "encoder parameters still require grad despite freeze_encoder=True"
assert head_trainable, "classifier head has no trainable parameters"

print("\n" + "="*80)
print("ALL CHECKS PASSED - READY FOR TRAINING!")
print("="*80)

In [None]:
"""
Convert Multivariate HAR to Multi-Channel Images
Handles [N, C, L] format correctly
"""

import torch
import numpy as np
from PIL import Image, ImageDraw
import os
from tqdm import tqdm

class MultiChannelTimeSeriesImageConverter:
    """
    Render a multivariate time series as a stack of per-channel RGB images.

    Each channel becomes its own line plot on a white background (blue, 2 px
    wide), min-max scaled to the image height with a 5% vertical margin.
    ``viz_type`` is stored on the instance but never consulted: only line
    plots are produced.
    """

    def __init__(self,
                 image_height: int = 224,
                 image_width: int = 224,
                 viz_type: str = 'line_plot'):
        # Configuration only; nothing is rendered until a convert_* call.
        self.viz_type = viz_type
        self.image_width = image_width
        self.image_height = image_height

    def convert_multivariate_sequence(self, x: torch.Tensor) -> torch.Tensor:
        """
        Render one sample as one image per channel.

        Args:
            x: 2-D tensor [C, L] (C channels of L timesteps).

        Returns:
            Float tensor [C, 3, H, W] with values in [0, 1].
        """
        num_channels, _ = x.shape  # the unpack also rejects inputs that are not 2-D
        rgb_frames = [
            self._create_line_plot_fast(x[idx].cpu().numpy())
            for idx in range(num_channels)
        ]
        # HWC uint8 -> CHW float scaled to [0, 1]
        return torch.stack([
            torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
            for frame in rgb_frames
        ])

    def convert_batch(self, sequences: torch.Tensor) -> torch.Tensor:
        """
        Render a batch of samples.

        Args:
            sequences: 3-D tensor [B, C, L].

        Returns:
            Float tensor [B, C, 3, H, W].
        """
        batch_size, _, _ = sequences.shape  # the unpack also rejects inputs that are not 3-D
        return torch.stack([
            self.convert_multivariate_sequence(sequences[idx])
            for idx in range(batch_size)
        ])

    def _create_line_plot_fast(self, x: np.ndarray) -> np.ndarray:
        """
        Draw a 1-D series as a line plot and return it as an [H, W, 3] uint8 array.

        Values are min-max normalised. A (near-)constant series (range < 1e-8)
        is drawn flat at mid-height. The original author notes this matches
        the Pix2Seq training renderer.
        """
        height, width = self.image_height, self.image_width

        lo, hi = x.min(), x.max()
        if hi - lo < 1e-8:
            scaled = np.ones_like(x) * 0.5
        else:
            scaled = (x - lo) / (hi - lo)

        pad = int(height * 0.05)
        ys = height - pad - (scaled * (height - 2*pad)).astype(int)  # row 0 is the top
        xs = np.linspace(0, width-1, len(x)).astype(int)
        polyline = list(zip(xs.tolist(), ys.tolist()))

        canvas = Image.new('RGB', (width, height), color='white')
        if len(polyline) > 1:
            ImageDraw.Draw(canvas).line(polyline, fill='blue', width=2)
        return np.array(canvas)


def _check_har_split_shape(X):
    """Assert that ``X`` is a 3-D [N, 9, 128] tensor, the layout this converter hard-codes."""
    assert X.dim() == 3, f"Expected 3D tensor, got {X.dim()}D"
    assert X.shape[1] == 9, f"Expected 9 channels, got {X.shape[1]}"
    assert X.shape[2] == 128, f"Expected 128 timesteps, got {X.shape[2]}"


def _convert_split_to_images(converter, sequences, batch_size, desc):
    """Render a [N, C, L] split into a [N, C, 3, H, W] tensor, ``batch_size`` samples at a time."""
    chunks = [
        converter.convert_batch(sequences[start:start + batch_size])
        for start in tqdm(range(0, len(sequences), batch_size), desc=desc)
    ]
    return torch.cat(chunks, dim=0)


def _save_split(path, images, labels, sequences, image_size):
    """Save one split to ``path`` in the ``*_multichannel_images.pt`` dict layout."""
    torch.save({
        'images': images,        # [N, 9, 3, H, W]
        'labels': labels,        # [N]
        'sequences': sequences,  # [N, 9, 128]
        'num_channels': 9,
        'sequence_length': 128,
        'image_size': image_size
    }, path)


def convert_har_multivariate_to_images(
    data_dir='data/HAR',
    output_dir='data/HAR/multichannel_images',
    image_size=224,
    viz_type='line_plot'
):
    """
    Convert multivariate HAR to multi-channel images.
    Handles [N, C, L] format correctly.

    Reads ``{data_dir}/{train,val,test}.pt``. Each file is a dict with
    'samples' [N, 9, 128] and 'labels' [N]. Every channel of every sample is
    rendered as its own RGB line plot, giving [N, 9, 3, image_size, image_size]
    float tensors in [0, 1]. Outputs written to ``output_dir``:
      * ``{split}_multichannel_images.pt`` for each split,
      * ``metadata.json``,
      * a few sample grids under ``samples/``.

    Memory: a whole split is held in memory as float32 before saving, which is
    about 5.4 MB per sample at image_size=224 (9 * 3 * 224 * 224 * 4 bytes).

    Args:
        data_dir: directory containing train.pt / val.pt / test.pt.
        output_dir: destination directory (created if missing).
        image_size: height and width of each rendered image.
        viz_type: recorded in the metadata; the converter always draws line plots.

    Returns:
        (train_images, val_images, test_images) tensors.

    Raises:
        AssertionError: if any split is not [N, 9, 128] or its sample and
            label counts differ. Both checks run before the slow conversion.
    """
    import json

    print("="*80)
    print("CONVERTING MULTIVARIATE HAR TO MULTI-CHANNEL IMAGES")
    print("="*80)
    print(f"Configuration:")
    print(f"  - Image size: {image_size}x{image_size}")
    print(f"  - Visualization: {viz_type}")
    print(f"  - Output: {output_dir}")

    # Load data
    print("\n[1/4] Loading HAR data...")
    train_dict = torch.load(f'{data_dir}/train.pt')
    val_dict = torch.load(f'{data_dir}/val.pt')
    test_dict = torch.load(f'{data_dir}/test.pt')

    X_train = train_dict['samples']  # [N, 9, 128]
    y_train = train_dict['labels']   # [N]
    X_val = val_dict['samples']      # [N, 9, 128]
    y_val = val_dict['labels']       # [N]
    X_test = test_dict['samples']    # [N, 9, 128]
    y_test = test_dict['labels']     # [N]

    print(f"  Train: {X_train.shape}, labels: {y_train.shape}")
    print(f"  Val: {X_val.shape}, labels: {y_val.shape}")
    print(f"  Test: {X_test.shape}, labels: {y_test.shape}")

    # (file-name key, display title, sequences, labels)
    splits = (
        ('train', 'Train', X_train, y_train),
        ('val', 'Val', X_val, y_val),
        ('test', 'Test', X_test, y_test),
    )

    # Verify shapes and sample/label counts of EVERY split before the
    # expensive conversion (previously only train shapes were checked, and
    # count mismatches surfaced only after converting everything).
    for _, title, X, y in splits:
        _check_har_split_shape(X)
        assert len(X) == len(y), f"{title} size mismatch!"

    # Initialize converter
    print("\n[2/4] Initializing multi-channel converter...")
    converter = MultiChannelTimeSeriesImageConverter(
        image_height=image_size,
        image_width=image_size,
        viz_type=viz_type
    )
    print("  ✓ Converter ready")

    # Convert
    print("\n[3/4] Converting sequences to multi-channel images...")
    batch_size = 50
    converted = {}
    for key, title, X, y in splits:
        print(f"  Converting {key}...")
        images = _convert_split_to_images(converter, X, batch_size, desc=title)
        print(f"    ✓ {title} images: {images.shape}")  # [N, 9, 3, H, W]
        assert len(images) == len(y), f"{title} size mismatch!"
        converted[key] = images
    train_images = converted['train']
    val_images = converted['val']
    test_images = converted['test']

    # Save
    print("\n[4/4] Saving datasets...")
    os.makedirs(output_dir, exist_ok=True)
    for key, _, X, y in splits:
        _save_split(f'{output_dir}/{key}_multichannel_images.pt', converted[key], y, X, image_size)

    # Save metadata
    metadata = {
        'num_classes': 6,
        'class_names': {
            0: 'WALKING',
            1: 'WALKING_UPSTAIRS',
            2: 'WALKING_DOWNSTAIRS',
            3: 'SITTING',
            4: 'STANDING',
            5: 'LAYING'
        },
        'num_channels': 9,
        'channel_names': [
            'body_acc_x', 'body_acc_y', 'body_acc_z',
            'body_gyro_x', 'body_gyro_y', 'body_gyro_z',
            'total_acc_x', 'total_acc_y', 'total_acc_z'
        ],
        'sequence_length': 128,
        'image_size': image_size,
        'viz_type': viz_type,
        'train_samples': len(train_images),
        'val_samples': len(val_images),
        'test_samples': len(test_images)
    }

    with open(f'{output_dir}/metadata.json', 'w') as f:
        json.dump(metadata, f, indent=2)

    # Save sample visualizations
    print("\nSaving sample visualizations...")
    sample_dir = f'{output_dir}/samples'
    os.makedirs(sample_dir, exist_ok=True)

    import torchvision.utils as vutils

    # Save first 5 samples, each as a 3x3 grid of its 9 channel images
    for i in range(min(5, len(train_images))):
        sample_img = train_images[i]  # [9, 3, H, W]
        label = y_train[i].item()
        class_name = metadata['class_names'][label]

        grid = vutils.make_grid(sample_img, nrow=3, padding=2)
        grid_np = (grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        Image.fromarray(grid_np).save(
            f'{sample_dir}/sample_{i}_label_{label}_{class_name}_all_channels.png'
        )

    print("\n" + "="*80)
    print("✓ CONVERSION COMPLETE!")
    print("="*80)
    print(f"\nFiles created:")
    print(f"  - {output_dir}/train_multichannel_images.pt")
    print(f"      Images: {train_images.shape}")
    print(f"      Labels: {y_train.shape}")
    print(f"  - {output_dir}/val_multichannel_images.pt")
    print(f"      Images: {val_images.shape}")
    print(f"      Labels: {y_val.shape}")
    print(f"  - {output_dir}/test_multichannel_images.pt")
    print(f"      Images: {test_images.shape}")
    print(f"      Labels: {y_test.shape}")
    print(f"  - {output_dir}/metadata.json")
    print(f"  - {sample_dir}/*.png (sample visualizations)")
    print(f"\n✅ Each sequence converted to 9 separate images (one per channel)")
    print(f"✅ Images use SAME format as Pix2Seq training")

    return train_images, val_images, test_images

In [None]:
# convert_har_multivariate_to_images()

In [None]:
# %%
# ============================================================================
# TRAINING SCRIPT
# ============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
from tqdm import tqdm
import json

from data_loader import create_har_dataloaders
from models.encoder_cls import EncoderClassifier

# Global speed knobs: set these once, before any model is built.
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls on Ampere+ GPUs
torch.backends.cudnn.allow_tf32 = True         # TF32 inside cuDNN convolutions
torch.backends.cudnn.benchmark = True          # let cuDNN auto-tune conv algorithms per input shape


# Use mixed precision training
from torch.amp import autocast
from torch.cuda.amp import GradScaler

def _evaluate_accuracy(model, loader, device, desc=None):
    """Return top-1 accuracy (in percent) of ``model`` over every batch in ``loader``.

    The caller is responsible for putting ``model`` in eval mode.

    Args:
        model: module mapping an image batch to logits of shape [B, num_classes].
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        desc: label for a tqdm progress bar; when None, no bar is shown.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).
    """
    correct = 0
    total = 0
    batches = loader if desc is None else tqdm(loader, desc=desc)
    with torch.no_grad():
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        raise ValueError("cannot compute accuracy: the data loader yielded no samples")
    return 100. * correct / total


def train_har_classifier(
    mode='flatten',
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-3,
    device='cuda',
    save_dir='checkpoints/har_classifier'
):
    """
    Train HAR classifier with pretrained Pix2Seq encoder.

    Args:
        mode: 'flatten', 'average', or 'first' (forwarded to create_har_dataloaders)
        batch_size: Training batch size
        num_epochs: Number of epochs (must be >= 1)
        learning_rate: Learning rate
        device: 'cuda' or 'cpu' (falls back to CPU when CUDA is unavailable)
        save_dir: Directory to save checkpoints

    Returns:
        (model, results): the trained model (``torch.compile``-wrapped when
        available) with the best-validation weights loaded, and the dict
        written to ``training_results.json``.

    Raises:
        ValueError: if ``num_epochs`` < 1 (no checkpoint would ever be written).

    Notes:
        * ``best_model.pth`` stores the state_dict of the *uncompiled* module,
          so its keys carry no ``_orig_mod.`` prefix and load directly into a
          plain EncoderClassifier.
        * Mixed precision (autocast + GradScaler) is enabled only on CUDA.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    Path(save_dir).mkdir(parents=True, exist_ok=True)
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'
    best_model_path = f'{save_dir}/best_model.pth'

    print("="*80)
    print("TRAINING HAR CLASSIFIER WITH PIX2SEQ ENCODER")
    print("="*80)
    print(f"Configuration:")
    print(f"  Mode: {mode}")
    print(f"  Batch size: {batch_size}")
    print(f"  Epochs: {num_epochs}")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Device: {device}")

    # Create dataloaders
    print("\n[1/5] Creating dataloaders...")
    train_loader, val_loader, test_loader = create_har_dataloaders(
        mode=mode,
        batch_size=batch_size,
        use_augmentation=True
    )

    # Create model
    print("\n[2/5] Loading model with pretrained encoder...")
    base_model = EncoderClassifier(
        num_classes=6,
        pretrained_encoder_path='/home/AD/sachith/pix2seq/data/HAR_pretrained/extracted_encoder/encoder_weights.pth',
        freeze_encoder=True,
        hidden_dims=[512],
        dropout=0.1,
        image_size=224
    ).to(device)
    # Keep a handle on the uncompiled module: torch.compile's wrapper prefixes
    # state_dict keys with '_orig_mod.'. It shares parameters with base_model.
    model = base_model
    if hasattr(torch, 'compile'):  # PyTorch 2.0+
        model = torch.compile(base_model, mode='reduce-overhead')

    # Loss and optimizer. Only parameters that require grad are optimised
    # (presumably just the head when freeze_encoder=True).
    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in base_model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Mixed precision scaler (a pass-through when AMP is disabled).
    scaler = GradScaler(enabled=use_amp)

    # Training loop
    print("\n[3/5] Training...")
    # Below any real accuracy, so epoch 1 always writes best_model.pth even if
    # validation accuracy is 0%.
    best_val_acc = -1.0
    train_losses = []
    val_accs = []

    for epoch in range(num_epochs):
        # Train
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            # Mixed precision forward pass (CUDA only)
            with autocast(device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            # Scaled backward pass
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Stats
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*train_correct/train_total:.2f}%'
            })

        train_loss /= len(train_loader)
        train_acc = 100. * train_correct / train_total
        train_losses.append(train_loss)

        # Validate
        model.eval()
        val_acc = _evaluate_accuracy(model, val_loader, device)
        val_accs.append(val_acc)

        # Update scheduler
        scheduler.step()

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'  Val Acc: {val_acc:.2f}%')
        print(f'  LR: {scheduler.get_last_lr()[0]:.6f}')

        # Save best model (uncompiled state_dict, see docstring)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': base_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'mode': mode
            }, best_model_path)
            print(f'  ✓ Saved best model (val_acc: {val_acc:.2f}%)')

    # Test evaluation with the best-validation weights
    print("\n[4/5] Evaluating on test set...")
    checkpoint = torch.load(best_model_path, map_location=device)
    base_model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    test_acc = _evaluate_accuracy(model, test_loader, device, desc='Testing')

    # Save training history
    print("\n[5/5] Saving results...")
    results = {
        'mode': mode,
        'best_val_acc': best_val_acc,
        'test_acc': test_acc,
        'train_losses': train_losses,
        'val_accs': val_accs,
        'num_epochs': num_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate
    }

    with open(f'{save_dir}/training_results.json', 'w') as f:
        json.dump(results, f, indent=2)

    print("\n" + "="*80)
    print("TRAINING COMPLETE!")
    print("="*80)
    print(f"Best Val Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Model saved to: {save_dir}/best_model.pth")

    return model, results


In [None]:

# %%
# ============================================================================
# RUN TRAINING
# ============================================================================

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

# CUDA_VISIBLE_DEVICES is read only once, when the CUDA runtime initialises.
# Earlier cells in this notebook (e.g. torch.cuda.current_device()) may already
# have initialised it; in that case the mask above is silently ignored.
if torch.cuda.is_initialized():
    print("WARNING: CUDA was initialised before CUDA_VISIBLE_DEVICES was set; "
          "the GPU selection above has no effect. Restart the kernel and run this cell first.")

if torch.cuda.is_available():
    print(f"Using GPU: {torch.cuda.current_device()}")
    # Physical GPU 1 shows up as device 0 only if the mask above took effect.
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    # current_device() / get_device_name() raise without CUDA.
    print("CUDA not available - training will fall back to CPU")

if __name__ == '__main__':
    # Training with FLATTEN mode (recommended)
    print("\n" + "="*80)
    print("TRAINING WITH FLATTEN MODE (9x more samples)")
    print("="*80)

    model_flatten, results_flatten = train_har_classifier(
        mode='flatten',
        batch_size=128,
        num_epochs=20,
        learning_rate=1e-3,
        save_dir='checkpoints/har_flatten'
    )
    