Self-Supervision for User Localisation

In [None]:
import os
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from sklearn.model_selection import train_test_split

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data loading function
def get_data(data_file):
    """Read the arrays stored in a single HDF5 file.

    Args:
        data_file: Path of the HDF5 file to open read-only.

    Returns:
        Tuple ``(H_Re, H_Im, SNR, Pos)`` read with ``[:]`` from the datasets
        of the same names. ``Pos`` is ``None`` when the file contains no
        ``'Pos'`` dataset.
    """
    with h5py.File(data_file, 'r') as h5:
        # All reads happen while the file is still open.
        h_re, h_im, snr = (h5[key][:] for key in ('H_Re', 'H_Im', 'SNR'))
        pos = h5['Pos'][:] if 'Pos' in h5 else None
    return h_re, h_im, snr, pos

# Load and concatenate labeled or unlabeled data
def load_data(files, labeled=True):
    """Load several HDF5 files and concatenate their arrays along axis 0.

    Args:
        files: Iterable of HDF5 file paths, each read with ``get_data``.
        labeled: When true, the per-file ``Pos`` arrays are concatenated and
            returned too. When false, ``Pos`` is ignored and ``None`` is
            returned in its place.

    Returns:
        Tuple ``(H_Re, H_Im, SNR, Pos)`` of concatenated arrays (``Pos`` is
        ``None`` for unlabeled data).

    Raises:
        ValueError: If ``labeled`` is true but at least one file has no
            ``'Pos'`` dataset. ``np.concatenate`` also raises ``ValueError``
            when ``files`` is empty.
    """
    files = list(files)
    data_list = [get_data(path) for path in files]

    if labeled:
        # get_data returns Pos=None for files without positions; report those
        # files by name instead of letting np.concatenate fail on a None.
        missing = [path for path, data in zip(files, data_list) if data[3] is None]
        if missing:
            raise ValueError(f"No 'Pos' dataset in labeled file(s): {missing}")

    H_Re = np.concatenate([data[0] for data in data_list])
    H_Im = np.concatenate([data[1] for data in data_list])
    SNR = np.concatenate([data[2] for data in data_list])
    Pos = np.concatenate([data[3] for data in data_list]) if labeled else None
    return H_Re, H_Im, SNR, Pos

# Dataset locations: labelled files 1-4 and unlabelled files 1-9
labeled_files = [f"labelled_data/file_{idx}.hdf5" for idx in range(1, 5)]
unlabeled_files = [f"unlabelled_data/file_{idx}.hdf5" for idx in range(1, 10)]

# Read both collections; only the labelled one yields positions
H_Re_labeled, H_Im_labeled, SNR_labeled, Pos_labeled = load_data(
    files=labeled_files, labeled=True
)
H_Re_unlabeled, H_Im_unlabeled, SNR_unlabeled, _ = load_data(
    files=unlabeled_files, labeled=False
)

# Combine and reshape data
def prepare_data(H_Re, H_Im):
    """Build a 2-D feature matrix from the real and imaginary parts.

    The two arrays are joined along axis 1, then everything after the first
    axis is flattened so that each entry along axis 0 becomes one row.
    """
    stacked = np.concatenate((H_Re, H_Im), axis=1)
    n_rows = stacked.shape[0]
    return stacked.reshape(n_rows, -1)

# Flatten each dataset into one feature row per sample
X_labeled = prepare_data(H_Re=H_Re_labeled, H_Im=H_Im_labeled)
X_unlabeled = prepare_data(H_Re=H_Re_unlabeled, H_Im=H_Im_unlabeled)

# Quick sanity check of the resulting array shapes
print(f"Labeled data shape: {X_labeled.shape}")
print(f"Unlabeled data shape: {X_unlabeled.shape}")
print(f"Position data shape: {Pos_labeled.shape}")

# Autoencoder definition
class Autoencoder(nn.Module):
    """Fully connected autoencoder built from Linear/ReLU/Dropout stages.

    The encoder narrows ``input_dim -> 2048 -> 1024 -> 512 -> encoding_dim``
    and the decoder widens back through the same sizes to ``input_dim``.
    Dropout follows every hidden ReLU except the one after the decoder's
    512 -> 1024 layer. The last layer of each half is a plain Linear.
    """

    def __init__(self, input_dim, encoding_dim, dropout_rate=0.2):
        super().__init__()

        def stage(n_in, n_out, with_dropout=True):
            # One hidden stage: Linear + ReLU, optionally followed by Dropout.
            layers = [nn.Linear(n_in, n_out), nn.ReLU()]
            if with_dropout:
                layers.append(nn.Dropout(dropout_rate))
            return layers

        self.encoder = nn.Sequential(
            *stage(input_dim, 2048),
            *stage(2048, 1024),
            *stage(1024, 512),
            nn.Linear(512, encoding_dim),
        )
        self.decoder = nn.Sequential(
            *stage(encoding_dim, 512),
            *stage(512, 1024, with_dropout=False),
            *stage(1024, 2048),
            nn.Linear(2048, input_dim),
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        code = self.encoder(x)
        return self.decoder(code)

# Set up and train autoencoder (reconstruction objective on the unlabeled data)
input_dim = X_unlabeled.shape[1]
encoding_dim = 256
autoencoder = Autoencoder(input_dim, encoding_dim).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters())
# Multiply the learning rate by 0.1 every 5 epochs
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

batch_size = 64
epochs = 25
X_unlabeled_tensor = torch.FloatTensor(X_unlabeled).to(device)

autoencoder.train()
for epoch in range(epochs):
    for i in range(0, X_unlabeled_tensor.size(0), batch_size):
        batch = X_unlabeled_tensor[i:i + batch_size]
        optimizer.zero_grad()
        # The input is its own reconstruction target
        loss = criterion(autoencoder(batch), batch)
        loss.backward()
        optimizer.step()
    scheduler.step()
    # NOTE: the value printed is the loss of the last mini-batch of the epoch
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# Extract features from labeled data using the encoder.
# The encoder contains Dropout layers, so it must be in eval mode here:
# otherwise the features used to train the position predictor are randomly
# masked, while the test-time features (encoded after autoencoder.eval())
# are not. no_grad() also avoids building an autograd graph for the whole
# labeled set.
autoencoder.eval()
with torch.no_grad():
    X_labeled_encoded = autoencoder.encoder(torch.FloatTensor(X_labeled).to(device)).cpu().numpy()

# Define Residual Block and Position Predictor
class ResidualBlock(nn.Module):
    """Two-layer Linear/BatchNorm block with a skip connection.

    Main path: Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d.
    Skip path: the identity when ``in_features == out_features``, otherwise
    a Linear projection to ``out_features``. A ReLU is applied to the sum of
    the two paths.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        main_path = [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
            nn.BatchNorm1d(out_features),
        ]
        self.block = nn.Sequential(*main_path)

        if in_features == out_features:
            self.shortcut = nn.Identity()
        else:
            # Sizes differ, so the input is projected before it can be added
            self.shortcut = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        summed = self.block(x) + self.shortcut(x)
        return torch.relu(summed)

class PositionPredictor(nn.Module):
    """Regression head mapping a feature vector to three outputs.

    Three ``ResidualBlock`` stages (``input_dim -> 512 -> 256 -> 128``) are
    followed by a Linear layer producing 3 values per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 512, 256, 128)
        # One residual stage per consecutive pair of widths
        stages = [
            ResidualBlock(n_in, n_out)
            for n_in, n_out in zip(widths, widths[1:])
        ]
        self.model = nn.Sequential(*stages, nn.Linear(widths[-1], 3))

    def forward(self, x):
        """Run ``x`` through the residual stages and the output layer."""
        return self.model(x)

# Train Position Predictor on the encoded labeled features (80/20 train/validation split)
X_train, X_val, y_train, y_val = train_test_split(X_labeled_encoded, Pos_labeled, test_size=0.2, random_state=42)
predictor = PositionPredictor(encoding_dim).to(device)
optimizer = optim.AdamW(predictor.parameters(), lr=0.001, weight_decay=0.01)

# Single source of truth for the epoch count, used by the loop and the log line
predictor_epochs = 500

# Build the tensors once instead of on every batch and every validation pass
X_train_tensor = torch.FloatTensor(X_train).to(device)
y_train_tensor = torch.FloatTensor(y_train).to(device)
X_val_tensor = torch.FloatTensor(X_val).to(device)
y_val_tensor = torch.FloatTensor(y_val).to(device)

predictor.train()
for epoch in range(predictor_epochs):
    for i in range(0, X_train_tensor.size(0), batch_size):
        batch_X = X_train_tensor[i:i + batch_size]
        batch_y = y_train_tensor[i:i + batch_size]
        if batch_X.size(0) < 2:
            # BatchNorm1d raises in training mode when a batch holds a single
            # sample, so a trailing batch of size 1 is skipped.
            continue
        optimizer.zero_grad()
        loss = criterion(predictor(batch_X), batch_y)
        loss.backward()
        optimizer.step()

    # Validate every 10 epochs; eval mode makes BatchNorm use running statistics
    if (epoch + 1) % 10 == 0:
        predictor.eval()
        with torch.no_grad():
            val_loss = criterion(predictor(X_val_tensor), y_val_tensor)
        predictor.train()
        # Train Loss is the loss of the last mini-batch of this epoch
        print(f'Epoch [{epoch+1}/{predictor_epochs}], Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')

# Test Data Loading, Processing, and Prediction
test_files = [f"test/file_{idx}.hdf5" for idx in range(1, 2)]
H_Re_test, H_Im_test, _, _ = load_data(files=test_files, labeled=False)
X_test = prepare_data(H_Re=H_Re_test, H_Im=H_Im_test)
X_test_tensor = torch.FloatTensor(X_test).to(device)

# Put both networks in eval mode and skip gradient tracking for inference
autoencoder.eval()
predictor.eval()
with torch.no_grad():
    X_test_encoded = autoencoder.encoder(X_test_tensor)
    predictions = predictor(X_test_encoded).cpu().numpy()

# Prepare submission: a running row id followed by the three predicted values
submission = np.column_stack([np.arange(len(predictions)), predictions])
np.savetxt(
    'submission.csv',
    submission,
    delimiter=',',
    header='id,x,y,z',
    comments='',
    fmt=['%d', '%.6f', '%.6f', '%.6f'],
)

print("Submission file created: submission.csv")
