In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np

# Seed every random number generator this script draws from (Python, NumPy,
# PyTorch CPU and the current CUDA device) with one shared value.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)

# Turn off the cuDNN autotuner and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



# Step 1: Synthetic paired-sensor dataset (each sample yields one observation per agent)
class PairedSensorDataset(Dataset):
    """Synthetic dataset of paired observations of one simulated reading.

    Every sample is generated once and split into two views:

    * ``agent1_obs`` -- ``[set_temp, actual_temp, humidity, pair_token]``
    * ``agent2_obs`` -- ``[power, hvac, house_id / 10, occupancy, pair_token]``

    Each feature is then min-max scaled using statistics computed over the
    generated samples themselves.

    Attributes:
        raw_data: list of un-normalised sample dicts (plain Python lists).
        data: list of normalised sample dicts (float32 tensors).
        x1_min, x1_max, x2_min, x2_max: per-feature float32 tensors holding
            the min/max used for scaling each view.

    Args:
        num_samples: number of samples to generate; must be positive.

    Raises:
        ValueError: if ``num_samples`` is not positive (no statistics can be
            computed from an empty dataset).
    """

    # Order matters: it fixes which label a given random draw selects.
    _TIMES_OF_DAY = ('morning', 'afternoon', 'evening', 'night')
    # Thermostat set-point associated with each time of day.
    _SET_TEMPS = {
        'morning': 21.0,
        'afternoon': 22.5,
        'evening': 23.0,
        'night': 20.0,
    }
    # Divisor applied to pair_id to form the shared pair token.
    _PAIR_TOKEN_SCALE = 10000.0

    def __init__(self, num_samples=10000):
        if num_samples <= 0:
            raise ValueError(
                f"num_samples must be positive, got {num_samples!r}"
            )
        self.raw_data = [self.generate_sample(i) for i in range(num_samples)]
        self._compute_normalization_stats()
        self.data = [self._normalize(sample) for sample in self.raw_data]

    def generate_sample(self, pair_id):
        """Draw one raw sample from the module-level ``random`` generator.

        Args:
            pair_id: index of the sample; ``pair_id / 10000`` is appended to
                both views as a signal shared by the two agents.

        Returns:
            dict with keys ``"agent1_obs"`` (4 values) and ``"agent2_obs"``
            (5 values), both plain lists.
        """
        # NOTE: the order of the random draws below is part of the behaviour;
        # reordering them changes the samples produced for a given seed.
        house_id = random.randint(0, 9)
        occupancy = random.choice([0, 1])
        time_of_day = random.choice(self._TIMES_OF_DAY)
        set_temp = self._SET_TEMPS[time_of_day]

        actual_temp = random.uniform(18.0, 26.0)
        humidity = random.uniform(30.0, 70.0)
        comfort_gap = set_temp - actual_temp
        humidity_factor = (humidity - 50) / 100
        occ_factor = 0.4 if occupancy else 0.1

        # HVAC load is clamped to the range [0.0, 2.5].
        hvac = min(max(0.0, comfort_gap * occ_factor + humidity_factor), 2.5)
        base_power = 0.3 + 0.2 * occupancy + 0.4 * hvac + house_id * 0.05
        power = base_power + random.uniform(0.0, 0.2)

        # Shared signal between agent1 and agent2.
        pair_token = pair_id / self._PAIR_TOKEN_SCALE

        agent1_obs = [set_temp, actual_temp, humidity, pair_token]
        agent2_obs = [power, hvac, house_id / 10.0, occupancy, pair_token]

        return {"agent1_obs": agent1_obs, "agent2_obs": agent2_obs}

    def _compute_normalization_stats(self):
        """Store per-feature min and max of both views as float32 tensors."""
        # Convert each view to an array once instead of once per statistic.
        x1_all = np.asarray([sample["agent1_obs"] for sample in self.raw_data])
        x2_all = np.asarray([sample["agent2_obs"] for sample in self.raw_data])
        self.x1_min = torch.tensor(x1_all.min(axis=0), dtype=torch.float32)
        self.x1_max = torch.tensor(x1_all.max(axis=0), dtype=torch.float32)
        self.x2_min = torch.tensor(x2_all.min(axis=0), dtype=torch.float32)
        self.x2_max = torch.tensor(x2_all.max(axis=0), dtype=torch.float32)

    def _normalize(self, sample):
        """Min-max scale one raw sample; returns a dict of float32 tensors."""
        x1 = torch.tensor(sample["agent1_obs"], dtype=torch.float32)
        x2 = torch.tensor(sample["agent2_obs"], dtype=torch.float32)
        # The epsilon keeps the division finite for constant features.
        x1_norm = (x1 - self.x1_min) / (self.x1_max - self.x1_min + 1e-8)
        x2_norm = (x2 - self.x2_min) / (self.x2_max - self.x2_min + 1e-8)
        return {"agent1_obs": x1_norm, "agent2_obs": x2_norm}

    def __getitem__(self, idx):
        """Return the normalised ``(agent1_obs, agent2_obs)`` pair at ``idx``."""
        sample = self.data[idx]
        return sample["agent1_obs"], sample["agent2_obs"]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

# Step 2: Encoders + Projection Head + Cosine Similarity
class AgentEncoder(nn.Module):
    """MLP that maps an observation vector to a unit-length embedding.

    ``encoder`` is the backbone (``input_dim -> 128 -> 64`` with a ReLU after
    each linear layer) and ``projector`` is the head (``64 -> 64 ->
    output_dim`` with a ReLU in between). The head's output is L2-normalised,
    so a dot product between two embeddings is their cosine similarity.
    """

    def __init__(self, input_dim, output_dim=128):
        super().__init__()
        wide, narrow = 128, 64

        backbone_layers = [
            nn.Linear(input_dim, wide),
            nn.ReLU(),
            nn.Linear(wide, narrow),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*backbone_layers)

        head_layers = [
            nn.Linear(narrow, narrow),
            nn.ReLU(),
            nn.Linear(narrow, output_dim),
        ]
        self.projector = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the L2-normalised projection of ``x`` (last dim = output_dim)."""
        return F.normalize(self.projector(self.encoder(x)), dim=-1)
    
# Step 3: Contrastive training with cosine similarity
def train_contrastive(
    num_samples=10000,
    batch_size=128,
    epochs=500,
    lr=1e-4,
    initial_temperature=0.2,
    min_temperature=0.01,
    decay_rate=0.99,
):
    """Train two agent encoders so that paired observations embed close together.

    For every batch the two encoders embed their own view, a similarity
    matrix between all embeddings of view 1 and view 2 is built, and a
    symmetric cross-entropy loss is applied in which the i-th row (and
    column) must pick the i-th entry, i.e. its own partner. One line with
    temperature, mean loss and mean matching accuracy is printed per epoch.

    Args:
        num_samples: number of synthetic samples to generate.
        batch_size: mini-batch size; also the number of candidates each
            sample is matched against.
        epochs: number of passes over the dataset.
        lr: Adam learning rate shared by both encoders.
        initial_temperature: base softmax temperature before decay.
        min_temperature: lower bound for the annealed temperature.
        decay_rate: multiplicative per-epoch temperature decay.

    Returns:
        Tuple ``(encoder1, encoder2)`` with the trained encoders, left on the
        device used for training.
    """
    dataset = PairedSensorDataset(num_samples=num_samples)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Size each encoder from the data rather than hard-coding feature counts.
    first_x1, first_x2 = dataset[0]
    encoder1 = AgentEncoder(input_dim=first_x1.shape[-1])
    encoder2 = AgentEncoder(input_dim=first_x2.shape[-1])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder1.to(device)
    encoder2.to(device)

    optimizer = torch.optim.Adam(
        list(encoder1.parameters()) + list(encoder2.parameters()), lr=lr
    )
    num_batches = len(loader)

    for epoch in range(1, epochs + 1):
        encoder1.train()
        encoder2.train()
        total_loss, total_acc = 0.0, 0.0

        # Exponentially anneal the temperature, never going below the floor.
        temperature = max(min_temperature, initial_temperature * (decay_rate ** epoch))

        for x1, x2 in loader:
            x1, x2 = x1.to(device), x2.to(device)
            z1 = encoder1(x1)
            z2 = encoder2(x2)

            # Embeddings are unit-length, so this is cosine similarity / T.
            sim_matrix = torch.matmul(z1, z2.T) / temperature

            # The correct partner of row i is column i; build the targets
            # directly on the right device to avoid a host-to-device copy.
            labels = torch.arange(sim_matrix.size(0), device=sim_matrix.device)

            loss_i = F.cross_entropy(sim_matrix, labels)
            loss_t = F.cross_entropy(sim_matrix.T, labels)
            loss = (loss_i + loss_t) / 2

            # Matching accuracy is only a metric; keep it out of autograd.
            with torch.no_grad():
                preds_i = sim_matrix.argmax(dim=1)
                preds_t = sim_matrix.T.argmax(dim=1)
                acc_i = (preds_i == labels).float().mean()
                acc_t = (preds_t == labels).float().mean()
                accuracy = (acc_i + acc_t) / 2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += accuracy.item()

        # Unweighted mean over batches (a smaller final batch counts equally).
        avg_loss = total_loss / num_batches
        avg_acc = total_acc / num_batches
        print(f"Epoch {epoch:3d} | Temp: {temperature:.4f} | Loss: {avg_loss:.4f} | Accuracy: {avg_acc:.4f}")

    # Hand the trained models back so they can be evaluated or saved.
    return encoder1, encoder2

# Run training: builds the dataset and both encoders, then prints
# temperature, loss and accuracy once per epoch.
train_contrastive()

Epoch   1 | Temp: 0.1980 | Loss: 4.7427 | Accuracy: 0.0222
Epoch   2 | Temp: 0.1960 | Loss: 3.9335 | Accuracy: 0.0673
Epoch   3 | Temp: 0.1941 | Loss: 3.3859 | Accuracy: 0.1346
Epoch   4 | Temp: 0.1921 | Loss: 3.2618 | Accuracy: 0.1813
Epoch   5 | Temp: 0.1902 | Loss: 3.1701 | Accuracy: 0.2338
Epoch   6 | Temp: 0.1883 | Loss: 3.0725 | Accuracy: 0.2761
Epoch   7 | Temp: 0.1864 | Loss: 2.9669 | Accuracy: 0.3023
Epoch   8 | Temp: 0.1845 | Loss: 2.8766 | Accuracy: 0.3194
Epoch   9 | Temp: 0.1827 | Loss: 2.8219 | Accuracy: 0.3482
Epoch  10 | Temp: 0.1809 | Loss: 2.7781 | Accuracy: 0.3800
Epoch  11 | Temp: 0.1791 | Loss: 2.7480 | Accuracy: 0.4024
Epoch  12 | Temp: 0.1773 | Loss: 2.7202 | Accuracy: 0.4144
Epoch  13 | Temp: 0.1755 | Loss: 2.6967 | Accuracy: 0.4294
Epoch  14 | Temp: 0.1737 | Loss: 2.6714 | Accuracy: 0.4392
Epoch  15 | Temp: 0.1720 | Loss: 2.6532 | Accuracy: 0.4414
Epoch  16 | Temp: 0.1703 | Loss: 2.6296 | Accuracy: 0.4472
Epoch  17 | Temp: 0.1686 | Loss: 2.6066 | Accuracy: 0.44