In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# --- 1. Encoder Module ---
class TSEncoder(nn.Module):
    """Convolutional encoder mapping (batch, seq_len, input_dims) -> (batch, seq_len, output_dims).

    The network is ``depth`` repetitions of Conv1d(kernel 3, padding 1) -> BatchNorm1d -> ReLU,
    each with ``hidden_dims`` channels, followed by a pointwise (kernel 1) Conv1d that projects
    to ``output_dims``. Kernel 3 with padding 1 keeps the time dimension length unchanged.

    Args:
        input_dims: channels per time step of the input.
        output_dims: channels per time step of the output.
        hidden_dims: width of every hidden conv block.
        depth: number of conv/BN/ReLU blocks before the final projection.
    """

    def __init__(self, input_dims, output_dims, hidden_dims=64, depth=3):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims

        # Channel count entering each hidden block: the first sees the raw input,
        # the rest see the previous block's hidden_dims channels.
        in_channels = [input_dims] + [hidden_dims] * (depth - 1) if depth > 0 else []
        modules = []
        for c_in in in_channels:
            modules += [
                nn.Conv1d(c_in, hidden_dims, kernel_size=3, padding=1),
                nn.BatchNorm1d(hidden_dims),
                nn.ReLU(),
            ]
        modules.append(nn.Conv1d(hidden_dims, output_dims, kernel_size=1))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        # Conv1d expects channels on dim 1, so move features there and back afterwards.
        return self.network(x.transpose(1, 2)).transpose(1, 2)

# --- 2. TS-TCC Core Model ---
class TS_TCC(nn.Module):
    """TS-TCC model: shared encoder + temporal contrasting + contextual projection head.

    Args:
        input_dims: channels per time step of the input series.
        latent_dims: encoder output size; also the transformer ``d_model`` and the size of
            the pooled context vectors. Must be divisible by ``n_heads``.
        temporal_unit_dims: output size of the projection head used for contextual contrasting.
        n_heads: attention heads per transformer layer.
        n_layers: number of transformer encoder layers.
    """

    def __init__(self, input_dims, latent_dims, temporal_unit_dims, n_heads=4, n_layers=2):
        super(TS_TCC, self).__init__()
        
        self.encoder = TSEncoder(input_dims=input_dims, output_dims=latent_dims)
        
        # Temporal Contrasting Module
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=latent_dims, 
            nhead=n_heads, 
            dim_feedforward=latent_dims*2,
            dropout=0.1,
            batch_first=True,
            norm_first=True # Pre-norm as mentioned in the paper
        )
        self.temporal_transformer = nn.TransformerEncoder(encoder_layers, num_layers=n_layers)
        
        # This linear layer acts as Wk in the paper to predict future latent steps
        self.temporal_predictor = nn.Linear(latent_dims, latent_dims)

        # Contextual Contrasting Module (Projection Head)
        self.projector = nn.Sequential(
            nn.Linear(latent_dims, latent_dims),
            nn.ReLU(),
            nn.Linear(latent_dims, temporal_unit_dims)
        )

    def _context(self, z_past):
        """Run the temporal transformer over ``z_past`` and mean-pool over time (dim 1)."""
        return self.temporal_transformer(z_past).mean(dim=1)

    def forward(self, x_strong, x_weak, future_k=5):
        """Encode both augmented views and compute the cross-view prediction targets.

        Args:
            x_strong: strongly augmented batch, (batch, seq_len, input_dims).
            x_weak: weakly augmented batch, (batch, seq_len, input_dims).
            future_k: number of trailing latent steps treated as the "future"; the
                remaining leading steps form the "past" fed to the transformer.

        Returns:
            Tuple ``(pred_from_strong, z_weak_future, pred_from_weak, z_strong_future,
            p_strong, p_weak)``. The ``pred_*`` and ``z_*_future`` tensors are
            (batch, future_k, latent_dims); ``p_*`` are (batch, temporal_unit_dims).

        Raises:
            ValueError: if ``future_k`` does not leave at least one past step and one
                future step in either view (i.e. not ``0 < future_k < seq_len``).
        """
        # Guard the slices below: future_k == 0 would make ``[:, :-0]`` empty, and
        # future_k >= seq_len leaves no past; both make the mean-pooled context NaN.
        for name, view in (("x_strong", x_strong), ("x_weak", x_weak)):
            seq_len = view.shape[1]
            if not 0 < future_k < seq_len:
                raise ValueError(
                    f"future_k must satisfy 0 < future_k < seq_len; got future_k={future_k} "
                    f"for {name} with seq_len={seq_len}"
                )

        # --- Get latent representations ---
        z_strong = self.encoder(x_strong)
        z_weak = self.encoder(x_weak)

        # --- Temporal Contrasting ---
        # Past = all but the last future_k steps; future = the last future_k steps.
        z_strong_past, z_strong_future = z_strong[:, :-future_k, :], z_strong[:, -future_k:, :]
        z_weak_past, z_weak_future = z_weak[:, :-future_k, :], z_weak[:, -future_k:, :]

        c_strong = self._context(z_strong_past)
        c_weak = self._context(z_weak_past)

        # Cross-view prediction: a single prediction per context, repeated over every
        # future step so it lines up with the (batch, future_k, latent_dims) targets.
        pred_from_strong = self.temporal_predictor(c_strong).unsqueeze(1).repeat(1, future_k, 1)
        pred_from_weak = self.temporal_predictor(c_weak).unsqueeze(1).repeat(1, future_k, 1)

        # --- Contextual Contrasting ---
        p_strong = self.projector(c_strong)
        p_weak = self.projector(c_weak)

        return pred_from_strong, z_weak_future, pred_from_weak, z_strong_future, p_strong, p_weak


# --- 3. Agent to handle training and representation extraction ---
class TS_TCC_Agent:
    """Training / inference wrapper around ``TS_TCC``.

    Owns the model, an Adam optimizer, the weak/strong augmentations and the combined
    objective ``lambda1 * L_TC + lambda2 * L_CC``.

    Args:
        input_dims: channels per time step of the input series.
        latent_dims: encoder output size; also the width of the vectors returned by
            ``get_representations``.
        temporal_unit_dims: output size of the contextual projection head.
        lr: Adam learning rate (weight decay is fixed at 3e-4).
        lambda1: weight of the temporal contrasting (MSE) loss.
        lambda2: weight of the contextual contrasting (cross-entropy) loss.
        temp: temperature dividing the cosine similarities in the contrastive losses.
        future_k: number of trailing time steps treated as the "future".
    """

    def __init__(self, input_dims, latent_dims=64, temporal_unit_dims=128, 
                 lr=3e-4, lambda1=1.0, lambda2=0.7, temp=0.2, future_k=5):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.model = TS_TCC(
            input_dims=input_dims,
            latent_dims=latent_dims,
            temporal_unit_dims=temporal_unit_dims
        ).to(self.device)
        
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=3e-4)
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.temp = temp
        self.future_k = future_k

    @staticmethod
    def _unpack_batch(batch):
        """Return the input tensor of a loader batch.

        Accepts a bare tensor, or a tuple/list whose first element is the input — e.g.
        ``(x,)`` from ``TensorDataset(x)`` or ``(x, y)`` from ``TensorDataset(x, y)``.
        """
        if isinstance(batch, (tuple, list)):
            return batch[0]
        return batch

    # --- Augmentation Methods ---
    def _jitter(self, x, sigma=0.1):
        """Add element-wise Gaussian noise with standard deviation ``sigma``."""
        return x + torch.randn_like(x) * sigma

    def _scale(self, x, sigma=0.1):
        """Multiply each sample (dim 0) by its own random factor drawn from N(1, sigma^2)."""
        # Draw the factor on x's device so the method also works on tensors not on self.device.
        factor = torch.randn(x.shape[0], 1, 1, device=x.device) * sigma + 1.0
        return x * factor

    def _permute(self, x, max_segments=5, seg_mode="random"):
        """Split the time axis (dim 1) into segments and shuffle them, independently per sample.

        For each sample the number of segments is drawn uniformly from
        ``[1, max_segments - 1]`` (capped at the sequence length); a single segment
        leaves that sample unchanged.

        Args:
            x: batch indexed as (batch, seq_len, ...).
            max_segments: exclusive upper bound on the number of segments; must be >= 2.
            seg_mode: ``"random"`` cuts at random (distinct, non-empty) positions; any
                other value splits into near-equal contiguous chunks.

        Raises:
            ValueError: if ``max_segments < 2``.
        """
        if max_segments < 2:
            raise ValueError(f"max_segments must be >= 2, got {max_segments}")

        batch_size, seq_len = x.shape[0], x.shape[1]
        orig_steps = np.arange(seq_len)
        num_segs = np.random.randint(1, max_segments, size=batch_size)

        # One time-index row per sample; rows start as the identity order.
        index = np.tile(orig_steps, (batch_size, 1))
        for b in range(batch_size):
            n = min(int(num_segs[b]), seq_len)
            if n < 2:
                continue
            if seg_mode == "random":
                # Cut points in [1, seq_len - 1] guarantee every segment is non-empty.
                cuts = np.sort(np.random.choice(np.arange(1, seq_len), n - 1, replace=False))
                segments = np.split(orig_steps, cuts)
            else:
                segments = np.array_split(orig_steps, n)
            order = np.random.permutation(len(segments))
            index[b] = np.concatenate([segments[k] for k in order])

        index = torch.as_tensor(index, dtype=torch.long, device=x.device)
        rows = torch.arange(batch_size, device=x.device).unsqueeze(1)
        return x[rows, index]

    def _get_augmentations(self, x_batch):
        """Return ``(x_strong, x_weak)`` views of ``x_batch``."""
        # Weak Augmentation: jitter then scale
        x_weak = self._scale(self._jitter(x_batch))
        # Strong Augmentation: segment permutation then jitter
        x_strong = self._jitter(self._permute(x_batch))
        return x_strong, x_weak

    # --- Loss Calculation ---
    def _info_nce_loss(self, query, positive, negatives):
        """InfoNCE with one positive and K explicit negatives per query (not used by ``train``).

        Shapes: query [batch, D], positive [batch, D], negatives [batch, K, D].
        """
        query = F.normalize(query, dim=1)
        positive = F.normalize(positive, dim=1)
        negatives = F.normalize(negatives, dim=2)
        
        l_pos = (query * positive).sum(dim=1).unsqueeze(1) # [batch, 1]
        l_neg = torch.bmm(query.unsqueeze(1), negatives.transpose(1, 2)).squeeze(1) # [batch, K]
        
        logits = torch.cat([l_pos, l_neg], dim=1)
        # The positive sits at column 0 of every row.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits / self.temp, labels)

    def _calculate_loss(self, pred_from_strong, z_weak_future, pred_from_weak, z_strong_future, p_strong, p_weak):
        """Return ``(total_loss, loss_tc, loss_cc)`` for one batch of model outputs."""
        # Temporal Contrastive Loss (L_TC): each view's context predicts the other view's future.
        loss_tc = F.mse_loss(pred_from_strong, z_weak_future) + F.mse_loss(pred_from_weak, z_strong_future)

        # Contextual Contrastive Loss (L_CC): row i of the cosine-similarity matrix scores
        # strong context i against every weak context. The matching weak context (the
        # diagonal) is the positive and the other batch - 1 entries are negatives.
        # Cross-entropy over the full row with target i is the same objective as moving
        # the positive to column 0 and targeting 0.
        p_strong_norm = F.normalize(p_strong, dim=1)
        p_weak_norm = F.normalize(p_weak, dim=1)
        sim_matrix = torch.matmul(p_strong_norm, p_weak_norm.T)  # [batch, batch]
        labels = torch.arange(sim_matrix.shape[0], device=sim_matrix.device)
        loss_cc = F.cross_entropy(sim_matrix / self.temp, labels)

        total_loss = self.lambda1 * loss_tc + self.lambda2 * loss_cc
        return total_loss, loss_tc, loss_cc

    def train(self, train_loader, epochs=50):
        """Run unsupervised pre-training and print the mean losses of each epoch.

        Args:
            train_loader: iterable of batches; each batch is a tensor or a tuple/list
                whose first element is the (batch, seq_len, input_dims) input.
            epochs: number of passes over ``train_loader``.

        Raises:
            ValueError: if ``train_loader`` yields no batches in an epoch.
        """
        self.model.train()
        print("--- Starting TS-TCC Pre-training ---")
        for epoch in range(epochs):
            total_loss, total_tc, total_cc = 0.0, 0.0, 0.0
            # Count batches actually seen: works for loaders without __len__ and avoids
            # dividing by zero on an empty loader.
            n_batches = 0
            for batch in train_loader:
                x_batch = self._unpack_batch(batch).to(self.device)
                x_strong, x_weak = self._get_augmentations(x_batch)

                self.optimizer.zero_grad()

                outputs = self.model(x_strong, x_weak, future_k=self.future_k)
                loss, tc, cc = self._calculate_loss(*outputs)

                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                total_tc += tc.item()
                total_cc += cc.item()
                n_batches += 1

            if n_batches == 0:
                raise ValueError("train_loader yielded no batches")

            avg_loss = total_loss / n_batches
            avg_tc = total_tc / n_batches
            avg_cc = total_cc / n_batches
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f} (TC: {avg_tc:.4f}, CC: {avg_cc:.4f})")

    def get_representations(self, data_loader):
        """Return mean-pooled transformer context vectors for every sample in ``data_loader``.

        The result is a numpy array with one row per sample, in loader order; its width is
        the encoder's output size. The model's train/eval mode is restored afterwards.

        Args:
            data_loader: iterable of batches; each batch is a tensor or a tuple/list whose
                first element is the input.
        """
        was_training = self.model.training
        self.model.eval()
        all_reps = []
        try:
            with torch.no_grad():
                for batch in data_loader:
                    x_batch = self._unpack_batch(batch).to(self.device)
                    z = self.model.encoder(x_batch)
                    # The context vector (transformer over the full sequence, averaged over
                    # time) is the representation used downstream.
                    c = self.model.temporal_transformer(z).mean(dim=1)
                    all_reps.append(c.cpu().numpy())
        finally:
            # Don't leave the model stuck in eval mode (dropout/BatchNorm) for later training.
            self.model.train(was_training)

        return np.concatenate(all_reps, axis=0)


# --- Example of how to use the agent ---
if __name__ == '__main__':
    # --- Dummy Parameters and Data for Demonstration ---
    BATCH_SIZE = 32
    SEQ_LEN = 128
    INPUT_DIMS = 9
    LATENT_DIMS = 64
    TEMPORAL_UNIT_DIMS = 128
    FUTURE_K = 10 # Predict last 10 steps

    # Create dummy data
    train_data = torch.randn(BATCH_SIZE * 10, SEQ_LEN, INPUT_DIMS)
    train_dataset = torch.utils.data.TensorDataset(train_data)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Create the agent
    agent = TS_TCC_Agent(
        input_dims=INPUT_DIMS,
        latent_dims=LATENT_DIMS,
        temporal_unit_dims=TEMPORAL_UNIT_DIMS,
        lr=3e-4,
        lambda1=1.0,
        lambda2=0.7,
        temp=0.2,
        future_k=FUTURE_K
    )

    # Train the model (unsupervised pre-training)
    agent.train(train_loader, epochs=5) # Keep low for demo

    # After training, extract representations for a downstream task (e.g., classification).
    # Use an unshuffled loader so row i of the result corresponds to train_data[i].
    eval_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print("\n--- Extracting representations for downstream task ---")
    representations = agent.get_representations(eval_loader)

    print(f"Shape of extracted representations: {representations.shape}")
    # The representations are the pooled context vectors, whose width is latent_dims
    # (the temporal_unit_dims projection head is only used for the training loss).
    print(f"Expected shape: ({len(train_data)}, {LATENT_DIMS})")