## Drive Connection

In [2]:
# Mount Google Drive inside the Colab VM; later cells read files under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Importing Dataset

In [None]:
import zipfile
import os
# Source archive on the mounted Drive and the local directory it is unpacked into.
zip_file_path = '/content/drive/MyDrive/Final_Year_Project.zip'
extract_path = '/content/Final_Year_Project'

# Unpack every member of the archive; the context manager closes the zip file afterwards.
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

# Report the destination and its top-level entries as a quick sanity check.
print("Contents extracted to:", extract_path)
print(os.listdir(extract_path))


Contents extracted to: /content/Final_Year_Project
['Final_Year_Project']


## Installing Dependencies

In [None]:
# Install the thop profiler, which the parameter-check cell uses to count MACs and parameters.
!pip install thop
# Replace the preinstalled PyTorch stack with the CUDA 11.8 wheels from the PyTorch index.
# NOTE(review): none of these packages is version-pinned, so the environment depends on the
# install date; thop also resolves its own torch dependency before torch is reinstalled here.
!pip uninstall torch torchvision torchaudio -y
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->thop)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->thop)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->thop)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->thop)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->thop)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->thop)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2

## Dataset Generation



In [None]:
import numpy as np
from scipy.linalg import toeplitz, pinv, sqrtm
from scipy.special import jv
import matplotlib.pyplot as plt
from tqdm import tqdm
import h5py
import yaml
import torch
from torch.utils.data import Dataset
from typing import Dict, List, Optional
import pandas as pd
import os

# Set device to GPU if available; MassiveMIMO1BitDataset moves its tensors to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# qd_realization is optional: when the import fails, ThreeGPPChannel finds no
# 'qd_realization' name in globals() and uses its built-in fallback model instead.
try:
    import qd_realization
except ImportError:
    print("Warning: qd_realization not installed - using enhanced fallback channel model")

# ======================== Configuration ========================
class Config:
    """Simulation-wide constants read by the dataset-generation code in this cell."""
    # System Parameters
    NUM_BS_ANTENNAS = 32  # base-station antennas; sizes the precoder and channel arrays
    NUM_UE = 4           # number of users; length of the transmitted symbol vector
    NUM_SUBCARRIERS = 64  # OFDM FFT size (first axis of the frequency-domain arrays)
    CP_LENGTH = 16        # cyclic-prefix samples prepended by OFDM.modulate
    PILOT_RATIO = 0.1     # not referenced by the code visible in this section

    # Simulation Parameters
    NUM_SAMPLES = 120000  # number of samples produced by run_simulation
    SNR_DB_RANGE = np.linspace(0, 20, 5)  # candidate SNR values in dB: 0, 5, 10, 15, 20
    DOPPLER_RANGE = [1, 5, 20, 100]  # not referenced by the code visible in this section
    VELOCITY_RANGE = [3, 60]  # (low, high) bounds for np.random.uniform; presumably m/s - TODO confirm

    # Hardware Impairments
    PHASE_NOISE_DEG = 5        # not referenced by the code visible in this section
    IQ_IMBALANCE = 0.03        # not referenced by the code visible in this section
    DAC_QUANTIZATION_BITS = 1  # not referenced; quantize_1bit is hard-wired to one bit
    DAC_NOISE_VAR = 0.01       # total complex noise variance added in apply_dac_effects
    PA_NONLINEARITY_ALPHA = 2.0  # exponent used in apply_pa_nonlinearity

    # 3GPP Channel
    SCENARIO = 'UMi'       # passed to qd_realization.TDL when that backend is available
    CARRIER_FREQ = 2.4e9   # carrier frequency in Hz; also used for the Doppler shift
    POLARIZATION = 'dual'  # 'dual' makes ThreeGPPChannel.generate append a size-2 axis

# ======================== 1-Bit Quantization ========================
def quantize_1bit(x: np.ndarray) -> np.ndarray:
    """Quantize complex values to one bit per real dimension.

    The real and imaginary parts are each replaced by their sign and the
    result is scaled by 1/sqrt(2), so entries with non-zero components map
    to (+-1 +-1j)/sqrt(2). A component that is exactly zero stays zero
    because ``np.sign(0) == 0``.

    Args:
        x: Array of (complex) samples of any shape.

    Returns:
        Complex array with the same shape as ``x``.
    """
    in_phase = np.sign(np.real(x))
    quadrature = np.sign(np.imag(x))
    constellation_point = in_phase + 1j * quadrature
    return constellation_point / np.sqrt(2)

def apply_dac_effects(x: np.ndarray) -> np.ndarray:
    """Model a 1-bit DAC: quantize the samples, then add complex Gaussian noise.

    The noise has variance ``Config.DAC_NOISE_VAR / 2`` per real dimension and
    is drawn from NumPy's global random state (real part first, then imaginary).

    Args:
        x: Complex samples to convert.

    Returns:
        Quantized samples plus noise, with the same shape as ``x``.
    """
    quantized = quantize_1bit(x)
    sigma = np.sqrt(Config.DAC_NOISE_VAR/2)
    noise_real = np.random.randn(*x.shape)
    noise_imag = np.random.randn(*x.shape)
    dac_noise = sigma * (noise_real + 1j*noise_imag)
    return quantized + dac_noise

# ======================== Enhanced 3GPP Channel Model ========================
class ThreeGPPChannel:
    """Frequency-domain MIMO channel source with an optional ``qd_realization`` backend.

    If a ``qd_realization`` name exists in the module globals, a
    ``qd_realization.TDL`` object built from ``Config.SCENARIO`` and
    ``Config.CARRIER_FREQ`` generates the channel. Otherwise a built-in
    tapped-delay-line fallback with exponentially correlated BS antennas is used.
    """

    # Correlation coefficient between adjacent BS antennas in the fallback model.
    _ANTENNA_CORRELATION = 0.7

    def __init__(self):
        self.channel = None
        # Matrix square root of the antenna correlation matrix; built on first use.
        self._corr_sqrt = None
        if 'qd_realization' in globals():
            self.channel = qd_realization.TDL(Config.SCENARIO, Config.CARRIER_FREQ)

    def generate(self, t: float, velocity: float) -> np.ndarray:
        """Return one channel realization.

        Args:
            t: Time instant fed to the fallback model's Doppler phase.
            velocity: User velocity passed to the backend or the fallback.

        Returns:
            The backend/fallback channel. When ``Config.POLARIZATION == 'dual'``
            a trailing axis of size 2 is appended: index 0 is the channel itself,
            index 1 is ``0.3 * H`` plus real-valued Gaussian noise of std 0.1.
        """
        # Explicit None test so a backend object that happens to be falsy is still used.
        if self.channel is not None:
            H = self.channel.generate(Config.NUM_BS_ANTENNAS, Config.NUM_UE, velocity)
        else:
            H = self._generate_enhanced_fallback(t, velocity)

        if Config.POLARIZATION == 'dual':
            H = np.stack([H, 0.3*H + np.random.randn(*H.shape)*0.1], axis=-1)
        return H

    def _spatial_corr_sqrt(self) -> np.ndarray:
        """Return the cached square root of the BS antenna correlation matrix.

        The matrix depends only on ``Config.NUM_BS_ANTENNAS``, so it is computed
        once instead of on every sample, and rebuilt only if that size changes.
        """
        num_bs = Config.NUM_BS_ANTENNAS
        cached = getattr(self, '_corr_sqrt', None)
        if cached is None or cached.shape[0] != num_bs:
            R = toeplitz(self._ANTENNA_CORRELATION**np.arange(num_bs))
            cached = sqrtm(R)
            self._corr_sqrt = cached
        return cached

    def _generate_enhanced_fallback(self, t: float, velocity: float) -> np.ndarray:
        """Draw a random 8-tap channel and return its frequency response.

        Returns:
            Complex array of shape (NUM_SUBCARRIERS, NUM_UE, NUM_BS_ANTENNAS).
        """
        num_taps = 8
        delays = np.sort(np.random.uniform(0, 300e-9, num_taps))
        velocity = float(velocity)
        doppler = velocity * Config.CARRIER_FREQ / 3e8
        time_phase = 2 * np.pi * doppler * t * np.random.uniform(-1, 1, num_taps)

        # Exponential power-delay profile with a per-tap Doppler phase, times
        # i.i.d. complex Gaussian fading (real part drawn before imaginary part).
        tap_shape = (num_taps, Config.NUM_UE, Config.NUM_BS_ANTENNAS)
        fading = np.random.randn(*tap_shape) + 1j*np.random.randn(*tap_shape)
        gains = np.exp(-delays/100e-9 + 1j*time_phase)[:, None, None] * fading

        # NOTE(review): delays are in seconds (<= 300e-9) but are multiplied by the
        # normalized subcarrier index, so the phase is ~1e-6 rad and the response is
        # almost flat across subcarriers. A subcarrier spacing seems to be missing;
        # left unchanged because Config defines no bandwidth.
        subcarriers = np.arange(Config.NUM_SUBCARRIERS)
        phase = -2*np.pi * subcarriers[:, None] / Config.NUM_SUBCARRIERS * delays[None, :]
        steering = np.exp(1j*phase)  # (subcarriers, taps)

        # Sum over taps for all subcarriers at once: (S, taps) @ (taps, UE*BS).
        H_freq = (steering @ gains.reshape(num_taps, -1)).reshape(
            Config.NUM_SUBCARRIERS, Config.NUM_UE, Config.NUM_BS_ANTENNAS)

        return H_freq @ self._spatial_corr_sqrt()

# ======================== Advanced 1-Bit Precoding ========================
def mmse_1bit_precoder(H: np.ndarray, s: np.ndarray) -> np.ndarray:
    """Compute a regularized MMSE transmit vector and quantize it to one bit.

    Args:
        H: Channel matrix of shape (NUM_UE, NUM_BS_ANTENNAS). A matrix of shape
            (NUM_BS_ANTENNAS, NUM_UE) is accepted and transposed (no conjugation).
        s: Symbol vector of shape (NUM_UE,).

    Returns:
        Complex vector of length NUM_BS_ANTENNAS, 1-bit quantized and rescaled
        to unit average power.

    Raises:
        AssertionError: If ``H`` or ``s`` does not have the expected shape.
    """
    ue_count = Config.NUM_UE
    bs_count = Config.NUM_BS_ANTENNAS

    # Antenna-major input is flipped into the (num_ue, num_bs) layout used below
    if H.shape[0] == bs_count and H.shape[1] == ue_count:
        H = H.T

    assert H.shape == (Config.NUM_UE, Config.NUM_BS_ANTENNAS), \
        f"Channel matrix must be (num_ue, num_bs), got {H.shape}"
    assert s.shape == (Config.NUM_UE,), \
        f"Symbol vector must be (num_ue,), got {s.shape}"

    # Transmit-side Gram matrix with diagonal loading of 1e-3
    H_herm = H.conj().T
    loaded_gram = H_herm @ H + 1e-3 * np.eye(bs_count)
    symbol_column = s.reshape(-1, 1)
    unquantized = pinv(loaded_gram) @ H_herm @ symbol_column

    # One bit per real dimension, then rescale to unit mean power
    one_bit = quantize_1bit(unquantized.flatten())
    mean_power = np.mean(np.abs(one_bit)**2)
    return one_bit / np.sqrt(mean_power)

def apply_pa_nonlinearity(x: np.ndarray) -> np.ndarray:
    """Apply a saturating power-amplifier model that keeps each sample's phase.

    The output magnitude is approximately ``1 - exp(-|x|**alpha)`` with
    ``alpha = Config.PA_NONLINEARITY_ALPHA``. A 1e-6 term in the denominator
    avoids division by zero for zero-valued samples.
    """
    magnitude = np.abs(x)
    saturation = 1 - np.exp(-magnitude**Config.PA_NONLINEARITY_ALPHA)
    # Multiply first, then divide, so floating-point rounding is unchanged
    scaled = x * saturation
    return scaled / (magnitude + 1e-6)

# ======================== OFDM Processing ========================
class OFDM:
    """Stateless OFDM modulator/demodulator operating along axis 0."""

    @staticmethod
    def modulate(x_freq: np.ndarray) -> np.ndarray:
        """Turn frequency-domain samples into an impaired time-domain signal.

        Steps: IFFT over axis 0, prepend the last ``Config.CP_LENGTH`` samples
        as a cyclic prefix, apply the DAC model, then the PA non-linearity.
        """
        time_domain = np.fft.ifft(x_freq, axis=0)
        cyclic_prefix = time_domain[-Config.CP_LENGTH:]
        with_prefix = np.concatenate([cyclic_prefix, time_domain], axis=0)
        return apply_pa_nonlinearity(apply_dac_effects(with_prefix))

    @staticmethod
    def demodulate(y_time: np.ndarray) -> np.ndarray:
        """Drop the cyclic prefix and FFT the next NUM_SUBCARRIERS samples."""
        first = Config.CP_LENGTH
        last = first + Config.NUM_SUBCARRIERS
        return np.fft.fft(y_time[first:last], axis=0)

# ======================== Dataset Generation ========================
class MassiveMIMO1BitDataset(Dataset):
    """In-memory dataset backed by an HDF5 file written by ``run_simulation``.

    The file must contain the datasets ``channels``, ``precoders_1bit`` and
    ``received_signals``. Each is loaded in full as a complex64 tensor.

    Args:
        file_path: Path of the HDF5 file to read.
        target_device: Device the tensors are moved to. Defaults to the
            module-level ``device`` (the original behaviour). Pass
            ``torch.device('cpu')`` when the data does not fit on the GPU or
            when the dataset is used with DataLoader worker processes.
    """

    def __init__(self, file_path: str, target_device: Optional[torch.device] = None):
        load_device = device if target_device is None else target_device
        with h5py.File(file_path, 'r') as f:
            self.channels = self._load(f, 'channels', load_device)
            self.precoders = self._load(f, 'precoders_1bit', load_device)
            self.received_signals = self._load(f, 'received_signals', load_device)

    @staticmethod
    def _load(h5_file, key: str, load_device) -> torch.Tensor:
        """Read one HDF5 dataset fully and return it as a complex64 tensor."""
        host_array = h5_file[key][()]
        # as_tensor reuses the NumPy buffer when the dtype already matches,
        # avoiding the second full-size host copy that torch.tensor() made.
        return torch.as_tensor(host_array, dtype=torch.complex64).to(load_device)

    def __len__(self):
        return len(self.channels)

    def __getitem__(self, idx):
        return {
            'channel': self.channels[idx],
            'precoder': self.precoders[idx],
            'received_signal': self.received_signals[idx]
        }

def generate_symbols() -> np.ndarray:
    """Draw one random QPSK symbol per user.

    Returns:
        Complex array of shape (NUM_UE,) whose real and imaginary parts are
        each -1 or +1 (not power-normalized). The in-phase bits are drawn
        before the quadrature bits from NumPy's global random state.
    """
    per_user = (Config.NUM_UE,)
    in_phase = 2*np.random.randint(0, 2, per_user) - 1
    quadrature = 2*np.random.randint(0, 2, per_user) - 1
    return in_phase + 1j*quadrature

def save_to_csv(results: dict, filename: str):
    """Write a reduced, human-readable view of the simulation results to CSV.

    One row is written per sample with these columns, in order:
    ``sample_index``, ``velocity``, ``snr_db``, the real/imag parts of every
    user's transmitted symbol, the real/imag parts of the first
    ``min(10, NUM_BS_ANTENNAS)`` precoder entries, and the real/imag parts of
    the first five received time samples.

    Args:
        results: Mapping with ``transmitted_symbols``, ``precoders_1bit`` and
            ``received_signals``. Optional per-sample arrays ``velocities`` and
            ``snr_db`` supply the values actually used in the simulation.
        filename: Destination CSV path.
    """
    num_samples = len(results['transmitted_symbols'])

    # Previously these two columns were filled with fresh random draws that had
    # no relation to the sample. Use the recorded values when available and NaN
    # otherwise, rather than fabricating metadata.
    velocities = results.get('velocities')
    snr_values = results.get('snr_db')

    data = []
    for i in tqdm(range(num_samples), desc="Saving to CSV"):
        row = {
            'sample_index': i,
            'velocity': np.nan if velocities is None else velocities[i],
            'snr_db': np.nan if snr_values is None else snr_values[i]
        }

        # Transmitted symbols
        for ue in range(Config.NUM_UE):
            sym = results['transmitted_symbols'][i][ue]
            row[f'symbol_ue{ue}_real'] = np.real(sym)
            row[f'symbol_ue{ue}_imag'] = np.imag(sym)

        # Precoder data (first 10 antennas for CSV)
        precoder = results['precoders_1bit'][i]
        for ant in range(min(10, Config.NUM_BS_ANTENNAS)):
            row[f'precoder_ant{ant}_real'] = np.real(precoder[ant])
            row[f'precoder_ant{ant}_imag'] = np.imag(precoder[ant])

        # Received signal (first 5 samples)
        rx_signal = results['received_signals'][i]
        for t in range(min(5, len(rx_signal))):
            # rx_signal[t] holds one value per antenna when the signal is 2-D, which
            # put a whole array repr into a single CSV cell. Keep one scalar per
            # column by taking the first antenna.
            # NOTE(review): antenna 0 is an assumption; confirm which antenna (or
            # reduction) these columns are meant to contain.
            rx_sample = np.ravel(rx_signal[t])[0]
            row[f'rx_time{t}_real'] = np.real(rx_sample)
            row[f'rx_time{t}_imag'] = np.imag(rx_sample)

        data.append(row)

    df = pd.DataFrame(data)
    df.to_csv(filename, index=False)
    print(f"Saved CSV data to {filename} (shape: {df.shape})")

def run_simulation():
    """Generate the full 1-bit precoding dataset and save it as HDF5 and CSV.

    For each of ``Config.NUM_SAMPLES`` samples: draw a velocity and a channel,
    draw user symbols, compute the 1-bit MMSE precoder from the centre
    subcarrier, OFDM-modulate it and add white Gaussian noise at a randomly
    chosen SNR.

    Returns:
        Dict of NumPy arrays keyed by ``channels``, ``precoders_1bit``,
        ``received_signals``, ``transmitted_symbols``, plus ``velocities`` and
        ``snr_db`` holding the per-sample values that were actually used.

    Side effects:
        Writes ``dataset_1bit.h5`` and ``dataset_1bit.csv`` to the working directory.
    """
    channel_model = ThreeGPPChannel()

    num_samples = Config.NUM_SAMPLES
    num_sc = Config.NUM_SUBCARRIERS
    num_ue = Config.NUM_UE
    num_bs = Config.NUM_BS_ANTENNAS
    centre_sc = num_sc // 2

    # The channel has a trailing polarization axis only in 'dual' mode.
    channel_shape = (num_samples, num_sc, num_ue, num_bs)
    if Config.POLARIZATION == 'dual':
        channel_shape += (2,)

    # Pre-allocate arrays for better performance.
    # NOTE(review): with the default Config the channel array alone is roughly
    # 15 GiB of complex64; lower NUM_SAMPLES if memory is limited.
    results = {
        'channels': np.zeros(channel_shape, dtype=np.complex64),
        'precoders_1bit': np.zeros((num_samples, num_bs), dtype=np.complex64),
        'received_signals': np.zeros((num_samples, num_sc + Config.CP_LENGTH, num_bs), dtype=np.complex64),
        'transmitted_symbols': np.zeros((num_samples, num_ue), dtype=np.complex64),
        # Per-sample metadata, recorded so exported files describe the real draw.
        'velocities': np.zeros(num_samples, dtype=np.float32),
        'snr_db': np.zeros(num_samples, dtype=np.float32),
    }

    for t in tqdm(range(num_samples), desc="Generating samples"):
        # Channel generation
        velocity = np.random.uniform(*Config.VELOCITY_RANGE)
        H_true = channel_model.generate(t/1e3, velocity)

        # Symbol generation
        symbols = generate_symbols()

        # 1-bit precoding with proper dimension handling
        H_center = H_true[centre_sc]
        if len(H_center.shape) == 3:  # If polarization dimension exists
            H_center = H_center[..., 0]  # Take first polarization
        if H_center.shape[0] == num_bs:
            H_center = H_center.T
        W_1bit = mmse_1bit_precoder(H_center, symbols)

        # OFDM transmission: only the centre subcarrier carries the precoded vector
        x_freq = np.zeros((num_sc, num_bs), dtype=complex)
        x_freq[centre_sc] = W_1bit.T
        x_time = OFDM.modulate(x_freq)

        # AWGN channel (the channel matrix itself is not applied to the signal here)
        snr_db = np.random.choice(Config.SNR_DB_RANGE)
        noise_var = 10**(-snr_db/10)
        y_time = x_time + np.sqrt(noise_var/2) * (
            np.random.randn(*x_time.shape) + 1j*np.random.randn(*x_time.shape))

        # Store results
        results['channels'][t] = H_true
        results['precoders_1bit'][t] = W_1bit
        results['received_signals'][t] = y_time
        results['transmitted_symbols'][t] = symbols
        results['velocities'][t] = velocity
        results['snr_db'][t] = snr_db

    # Save to HDF5; the arrays are passed directly to avoid copying them first
    h5_filename = 'dataset_1bit.h5'
    with h5py.File(h5_filename, 'w') as f:
        for key, val in results.items():
            f.create_dataset(key, data=val, compression='gzip')

    # Save to CSV
    csv_filename = 'dataset_1bit.csv'
    save_to_csv(results, csv_filename)

    return results

if __name__ == "__main__":
    # Clear GPU cache
    torch.cuda.empty_cache()

    results = run_simulation()
    dataset = MassiveMIMO1BitDataset('dataset_1bit.h5')

    first_sample = dataset[0]
    first_precoder = first_sample['precoder'].cpu().numpy()

    # A valid 1-bit precoder has |Re| == |Im| == 1/sqrt(2) on every antenna.
    # The old check compared the whole vector with a 2-element list (shape
    # mismatch) and used exact equality between float32 data and float64 constants.
    level = 1/np.sqrt(2)
    one_bit_ok = bool(
        np.allclose(np.abs(first_precoder.real), level, atol=1e-6)
        and np.allclose(np.abs(first_precoder.imag), level, atol=1e-6)
    )

    print("\n=== Dataset Summary ===")
    print(f"Total samples: {len(dataset)}")
    print(f"Channel shape: {first_sample['channel'].shape}")
    print(f"Precoder shape: {first_sample['precoder'].shape}")
    print(f"Received signal shape: {first_sample['received_signal'].shape}")
    print("\n=== Validation Checks ===")
    print(f"1-bit constraint: {one_bit_ok}")
    print(f"Avg precoder power: {np.mean(np.abs(first_precoder)**2):.4f}")
    print(f"CSV file created: {os.path.exists('dataset_1bit.csv')}")

## Model

In [3]:
import torch
import torch.nn as nn
import numpy as np

class EfficientDepthwiseConv(nn.Module):
    """Depthwise-separable convolution followed by GroupNorm and SiLU.

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of output feature maps (must be divisible by 4
            for the 4-group normalization).
        stride: Stride of the 3x3 depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 3x3 filter applied to each channel separately (groups == in_channels)
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
        )
        # 1x1 convolution mixes channels and sets the output width
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )
        self.norm = nn.GroupNorm(num_groups=4, num_channels=out_channels)
        self.act = nn.SiLU()

    def forward(self, x):
        features = self.pointwise(self.depthwise(x))
        return self.act(self.norm(features))

class LightweightAttention(nn.Module):
    """Channel attention: global average pool, two-layer MLP, sigmoid gate.

    Args:
        channels: Number of feature maps of the input.
        reduction_ratio: Factor by which the hidden layer is narrower than
            ``channels``.
    """

    def __init__(self, channels, reduction_ratio=8):
        super().__init__()
        hidden = channels // reduction_ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, chans, _, _ = x.shape
        # One descriptor per channel, turned into a gate in (0, 1)
        descriptor = self.avg_pool(x).reshape(batch, chans)
        gate = self.fc(descriptor).reshape(batch, chans, 1, 1)
        return x * gate

class OptimizedResBlock(nn.Module):
    """Residual block: separable conv -> channel attention -> separable conv.

    When ``in_ch`` differs from ``out_ch`` the skip path is a bias-free 1x1
    convolution followed by a 4-group GroupNorm; otherwise it is an empty
    ``nn.Sequential`` that passes the input through unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = EfficientDepthwiseConv(in_ch, out_ch)
        self.attn = LightweightAttention(out_ch)
        self.conv2 = EfficientDepthwiseConv(out_ch, out_ch)

        # Projection layers are only needed when the channel count changes
        if in_ch == out_ch:
            projection = []
        else:
            projection = [
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
                nn.GroupNorm(4, out_ch),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        skip = self.shortcut(x)
        main = self.conv2(self.attn(self.conv1(x)))
        return main + skip

class PowerAwareQuantizer(nn.Module):
    """1-bit quantizer with a learnable input scale and a straight-through gradient.

    Forward pass: the input is divided by ``beta * sqrt(2)``, clipped to
    [-1, 1] and replaced by its sign, so every non-zero entry becomes exactly
    -1 or +1 (unit power per element). An input of exactly zero stays zero.

    Backward pass: ``torch.sign`` has zero gradient everywhere, which would
    stop all learning upstream of this layer. The gradient of the clipped
    value is used instead (straight-through estimator), so gradients reach
    the network and ``beta`` wherever the scaled input lies inside [-1, 1].
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(1.0))
        self.sqrt2 = np.sqrt(2)

    def forward(self, x):
        x_scaled = x / (self.beta * self.sqrt2)
        x_clipped = torch.clamp(x_scaled, -1, 1)
        hard = torch.sign(x_clipped)
        # (x_clipped - x_clipped.detach()) is exactly zero in the forward pass,
        # so the output equals `hard`, but it carries d(x_clipped)/dx backward.
        return hard + (x_clipped - x_clipped.detach())


class UltraEfficientPrecoder(nn.Module):
    """CNN that maps per-user channel maps to a 1-bit precoding vector.

    Args:
        num_subcarriers: Height of the input maps.
        num_bs_antennas: Width of the input maps; the output has
            ``2 * num_bs_antennas`` entries.
        num_users: Number of users in the input (default 4, the value that was
            previously hard-coded).

    Input:  ``[batch, num_users, num_subcarriers, num_bs_antennas, 2]`` where the
            last axis is treated as the (real, imag) pair.
    Output: ``[batch, 2 * num_bs_antennas]`` produced by ``PowerAwareQuantizer``.
    """

    def __init__(self, num_subcarriers=64, num_bs_antennas=32, num_users=4):
        super().__init__()
        self.M = num_bs_antennas
        self.num_users = num_users
        # Real and imaginary parts of every user become separate input planes
        in_planes = 2 * num_users

        # Input processing for complex channels
        self.input_processor = nn.Sequential(
            nn.Conv2d(in_planes, 32, kernel_size=3, padding=1),
            nn.GroupNorm(4, 32),
            nn.SiLU()
        )

        # Downsampling
        self.conv2 = EfficientDepthwiseConv(32, 64, stride=2)

        # Residual blocks
        self.block1 = OptimizedResBlock(64, 64)
        self.block2 = OptimizedResBlock(64, 64)

        # Probe the feature extractor once to size the linear layer
        with torch.no_grad():
            dummy = torch.zeros(1, in_planes, num_subcarriers, num_bs_antennas)
            dummy = self._extract_features(dummy)
            self.flattened_size = dummy.numel() // dummy.shape[0]

        # Output layers
        self.fc = nn.Linear(self.flattened_size, 2*self.M)  # 2* for real/imag output if needed
        self.quant = PowerAwareQuantizer()

    def _extract_features(self, x):
        """Run the convolutional trunk on a [batch, 2*num_users, H, W] tensor."""
        x = self.input_processor(x)
        x = self.conv2(x)
        x = self.block1(x)
        return self.block2(x)

    def forward(self, x):
        # A wrong layout could otherwise be reshaped silently into nonsense planes
        if x.dim() != 5 or x.size(1) != self.num_users or x.size(4) != 2:
            raise ValueError(
                f"Expected input of shape [batch, {self.num_users}, subcarriers, antennas, 2], "
                f"got {tuple(x.shape)}"
            )
        batch_size, _, num_sc, num_ant, _ = x.shape

        # Move the real/imag axis next to the user axis and merge the two:
        # [batch, users, 2, sc, ant] -> [batch, 2*users, sc, ant]
        x = x.permute(0, 1, 4, 2, 3).reshape(batch_size, 2*self.num_users, num_sc, num_ant)

        x = self._extract_features(x)

        # Flatten and output
        x = x.reshape(batch_size, -1)
        x = self.fc(x)
        return self.quant(x)

# Verification Test
if __name__ == "__main__":
    # Smoke test: push one random batch through a freshly initialised model.
    # `model` and `x` stay defined at module level and are reused by the
    # profiling cell that follows.
    # Config
    num_subcarriers = 64
    num_bs_antennas = 32
    batch_size = 4

    # Correct complex input tensor [batch, 4, num_subcarriers, num_bs_antennas, 2]
    x = torch.randn(batch_size, 4, num_subcarriers, num_bs_antennas, 2)

    # Initialize model
    model = UltraEfficientPrecoder(num_subcarriers, num_bs_antennas)

    # Forward pass (no gradients needed for a shape/level check)
    with torch.no_grad():
        out = model(x)

    # Report shapes, the distinct output levels and the mean output power
    print("\n=== Verification ===")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {out.shape}")
    print(f"Quantization levels: {torch.unique(out)}")
    print(f"Output power: {torch.mean(torch.abs(out)**2):.4f}")


=== Verification ===
Input shape: torch.Size([4, 4, 64, 32, 2])
Output shape: torch.Size([4, 64])
Quantization levels: tensor([-1.0000,  1.0000])
Output power: 1.0000


## Parameter Check

In [None]:
from thop import profile
from thop import clever_format

# thop.profile counts multiply-accumulate operations (MACs) and parameters for
# one forward pass of `model` on `x`.
macs, params = profile(model, inputs=(x, ))

# clever_format adds a unit suffix (K/M/G). The value is a MAC count, so it is
# labelled "MACs" rather than "GFLOPs".
macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}, Params: {params}")

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
GFLOPs: 70.591M, Params: 2.123M


## Dataset Inspection

In [None]:
import h5py
def inspect_h5_file(filepath):
    """Print the name and shape of every object stored in an HDF5 file.

    Objects without a ``shape`` attribute (groups) are reported as 'group'.
    """
    def describe_entry(name, obj):
        # Called by h5py for each member; returning None continues the walk
        print(f"  {name}: shape={obj.shape if hasattr(obj, 'shape') else 'group'}")

    with h5py.File(filepath, 'r') as h5_handle:
        print(f"\nStructure of {filepath}:")
        h5_handle.visititems(describe_entry)

# Run this on one of your files to see the actual structure
inspect_h5_file('/content/Final_Year_Project/Final_Year_Project/Dataset/train/h5/dataset_1bit_1.h5')


Structure of /content/Final_Year_Project/Final_Year_Project/Dataset/train/h5/dataset_1bit_1.h5:
  channels: shape=(40000, 64, 4, 32, 2)
  precoders_1bit: shape=(40000, 32)
  received_signals: shape=(40000, 80, 32)
  transmitted_symbols: shape=(40000, 4)


## CSV Data Inspection

In [None]:
import pandas as pd
# Load one generated CSV and list its column names to check the schema.
# NOTE(review): this path uses 'Final_Year_Project' while the training config
# uses 'Final_year_Project'; confirm which spelling exists on Drive.
df = pd.read_csv('/content/drive/MyDrive/Final_Year_Project/csv/dataset_1bit(1).csv')
print("Columns:", df.columns.tolist())

Columns: ['sample_index', 'velocity', 'snr_db', 'symbol_ue0_real', 'symbol_ue0_imag', 'symbol_ue1_real', 'symbol_ue1_imag', 'symbol_ue2_real', 'symbol_ue2_imag', 'symbol_ue3_real', 'symbol_ue3_imag', 'precoder_ant0_real', 'precoder_ant0_imag', 'precoder_ant1_real', 'precoder_ant1_imag', 'precoder_ant2_real', 'precoder_ant2_imag', 'precoder_ant3_real', 'precoder_ant3_imag', 'precoder_ant4_real', 'precoder_ant4_imag', 'precoder_ant5_real', 'precoder_ant5_imag', 'precoder_ant6_real', 'precoder_ant6_imag', 'precoder_ant7_real', 'precoder_ant7_imag', 'precoder_ant8_real', 'precoder_ant8_imag', 'precoder_ant9_real', 'precoder_ant9_imag', 'rx_time0_real', 'rx_time0_imag', 'rx_time1_real', 'rx_time1_imag', 'rx_time2_real', 'rx_time2_imag', 'rx_time3_real', 'rx_time3_imag', 'rx_time4_real', 'rx_time4_imag']


In [None]:
# Column names the training dataset requires: real/imag symbol columns for
# 4 users followed by real/imag precoder columns for the first 10 antennas.
expected = [
    column
    for i in range(4)
    for column in (f'symbol_ue{i}_real', f'symbol_ue{i}_imag')
]
expected += [
    column
    for i in range(10)
    for column in (f'precoder_ant{i}_real', f'precoder_ant{i}_imag')
]
print("Expected:", expected)

Expected: ['symbol_ue0_real', 'symbol_ue0_imag', 'symbol_ue1_real', 'symbol_ue1_imag', 'symbol_ue2_real', 'symbol_ue2_imag', 'symbol_ue3_real', 'symbol_ue3_imag', 'precoder_ant0_real', 'precoder_ant0_imag', 'precoder_ant1_real', 'precoder_ant1_imag', 'precoder_ant2_real', 'precoder_ant2_imag', 'precoder_ant3_real', 'precoder_ant3_imag', 'precoder_ant4_real', 'precoder_ant4_imag', 'precoder_ant5_real', 'precoder_ant5_imag', 'precoder_ant6_real', 'precoder_ant6_imag', 'precoder_ant7_real', 'precoder_ant7_imag', 'precoder_ant8_real', 'precoder_ant8_imag', 'precoder_ant9_real', 'precoder_ant9_imag']


## Training and Testing

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from scipy.spatial.distance import cosine
import math

class ChunkedCSVDataset(Dataset):
    """Dataset built from the generated CSV files, read chunk by chunk.

    Each item is a ``(channel, precoder)`` pair of float32 tensors:

    * ``channel`` has shape ``[num_users, num_subcarriers, num_antennas, 2]``
      and holds each user's (real, imag) symbol repeated over every subcarrier
      and antenna position.
    * ``precoder`` has shape ``[2 * num_precoder_antennas]`` with interleaved
      real/imag parts of the precoder columns found in the CSV.

    Only the compact per-sample values are kept in memory (``self.samples``
    holds ``(symbols[num_users, 2], precoder)`` pairs). The repeated channel
    tensor is built on demand in ``__getitem__`` rather than stored as a dense
    float64 array per sample.

    Args:
        csv_paths: Non-empty sequence of CSV file paths.
        num_subcarriers: Second dimension of the returned channel tensor.
        num_antennas: Third dimension of the returned channel tensor.
        chunk_size: Number of CSV rows parsed at a time.
        num_users: Number of ``symbol_ue{i}_*`` column pairs to read.
        num_precoder_antennas: Number of ``precoder_ant{i}_*`` column pairs to read.

    Raises:
        ValueError: If ``csv_paths`` is empty or any file lacks a required column.
    """

    def __init__(self, csv_paths, num_subcarriers=64, num_antennas=32, chunk_size=5000,
                 num_users=4, num_precoder_antennas=10):
        if len(csv_paths) == 0:
            raise ValueError("csv_paths must contain at least one file")

        self.csv_paths = csv_paths
        self.num_subcarriers = num_subcarriers
        self.num_antennas = num_antennas
        self.chunk_size = chunk_size
        self.num_users = num_users
        self.num_precoder_antennas = num_precoder_antennas
        self.samples = []

        for path in csv_paths:
            # Check the header of every file (not just the first) before parsing it
            header = pd.read_csv(path, nrows=0)
            self._verify_columns(header.columns)

            chunk_reader = pd.read_csv(path, chunksize=self.chunk_size)
            for chunk in chunk_reader:
                self._process_chunk(chunk)

    def _symbol_columns(self):
        """Return the [real, imag] column-name pair for each user."""
        return [[f'symbol_ue{i}_real', f'symbol_ue{i}_imag'] for i in range(self.num_users)]

    def _precoder_columns(self):
        """Return the precoder column names, interleaved real/imag per antenna."""
        names = []
        for i in range(self.num_precoder_antennas):
            names.extend([f'precoder_ant{i}_real', f'precoder_ant{i}_imag'])
        return names

    def _verify_columns(self, columns):
        """Raise ValueError if any required symbol or precoder column is absent."""
        required = [name for pair in self._symbol_columns() for name in pair]
        required.extend(self._precoder_columns())

        missing = set(required) - set(columns)
        if missing:
            raise ValueError(f"Missing columns: {missing}")

    def _process_chunk(self, chunk):
        """Append the compact (symbols, precoder) pair of every row in ``chunk``."""
        # [rows, num_users, 2] with the last axis ordered (real, imag)
        ue_symbols = np.stack([
            chunk[pair].to_numpy(dtype=np.float32)
            for pair in self._symbol_columns()
        ], axis=1)

        # [rows, 2 * num_precoder_antennas], interleaved real/imag
        precoders = np.ascontiguousarray(
            chunk[self._precoder_columns()].to_numpy(dtype=np.float32))

        self.samples.extend(zip(ue_symbols, precoders))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        symbols, precoder = self.samples[idx]
        # Copy out of the stored arrays so callers cannot mutate the dataset
        symbol_tensor = torch.tensor(symbols, dtype=torch.float32)
        channel = symbol_tensor[:, None, None, :].expand(
            -1, self.num_subcarriers, self.num_antennas, -1).contiguous()
        return (
            channel,
            torch.tensor(precoder, dtype=torch.float32)
        )

def calculate_snr(predicted, target, eps=1e-12):
    """Calculate Signal-to-Noise Ratio (SNR) in dB.

    The signal is ``target`` and the noise is ``predicted - target``; both
    powers are means of squared values.

    Args:
        predicted: Predicted tensor.
        target: Reference tensor of the same shape.
        eps: Lower bound applied to both powers so that a perfect prediction
            or an all-zero target gives a large finite value instead of
            +/-inf, which would otherwise corrupt the per-epoch averages.

    Returns:
        SNR in dB as a Python float.
    """
    noise = predicted - target
    signal_power = torch.mean(torch.square(target)).clamp_min(eps)
    noise_power = torch.mean(torch.square(noise)).clamp_min(eps)
    snr = 10 * torch.log10(signal_power / noise_power)
    return snr.item()

def calculate_ber(predicted, target):
    """Calculate Bit Error Rate (BER) for 1-bit precoding.

    Every element is mapped to a bit by the test ``value > 0``; the BER is the
    fraction of positions where the predicted and target bits differ.

    Returns:
        BER as a Python float.
    """
    pred_is_one = predicted > 0
    target_is_one = target > 0
    mismatches = (pred_is_one != target_is_one).sum()
    return (mismatches / target_is_one.numel()).item()

def calculate_cosine_similarity(predicted, target):
    """Calculate Cosine Similarity between predicted and target.

    Both tensors are flattened, detached from the autograd graph and moved to
    the CPU; the result is one minus SciPy's cosine distance.
    """
    def to_vector(tensor):
        # Flatten first so the element order matches the tensor's logical layout
        return tensor.flatten().detach().cpu().numpy()

    cosine_distance = cosine(to_vector(predicted), to_vector(target))
    return 1 - cosine_distance

def train_model():
    # Configuration
    config = {
        'num_subcarriers': 64,
        'num_bs_antennas': 32,
        'batch_size': 16,  # Increased from 8
        'num_epochs': 40,  # Changed to 40 epochs
        'learning_rate': 3e-4,
        'save_dir': r'/content/drive/MyDrive/Final_year_Project/csv/precoder_checkpoints',
        'csv_files': [
            r'/content/drive/MyDrive/Final_year_Project/csv/dataset_1bit(1).csv',
            r'/content/drive/MyDrive/Final_year_Project/csv/dataset_1bit(1).csv'
        ],
        'chunk_size': 1000,
        'checkpoint_freq': 5,
        'lr_scheduler': {
            'mode': 'min',
            'factor': 0.5,
            'patience': 3,
            'threshold': 0.0001
        }
    }

    # Setup
    os.makedirs(config['save_dir'], exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize enhanced model
    model = UltraEfficientPrecoder(
        num_subcarriers=config['num_subcarriers'],
        num_bs_antennas=config['num_bs_antennas']
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'])
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=config['lr_scheduler']['mode'],
        factor=config['lr_scheduler']['factor'],
        patience=config['lr_scheduler']['patience'],
        threshold=config['lr_scheduler']['threshold']
    )

    # Track metrics
    metrics = {
        'train': {'loss': [], 'snr': [], 'ber': [], 'cosine': []},
        'val': {'loss': [], 'snr': [], 'ber': [], 'cosine': []},
        'test': {'loss': 0, 'snr': 0, 'ber': 0, 'cosine': 0}
    }

    try:
        # Load dataset
        dataset = ChunkedCSVDataset(
            config['csv_files'],
            num_subcarriers=config['num_subcarriers'],
            num_antennas=config['num_bs_antennas'],
            chunk_size=config['chunk_size']
        )

        # Split dataset
        train_size = int(0.8 * len(dataset))
        val_size = int(0.1 * len(dataset))
        test_size = len(dataset) - train_size - val_size

        train_dataset, val_dataset, test_dataset = random_split(
            dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42)
        )

        print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        # DataLoaders with increased workers
        train_loader = DataLoader(
            train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        # Training loop for 40 epochs
        best_val_loss = float('inf')
        for epoch in range(config['num_epochs']):
            model.train()
            epoch_train_metrics = {'loss': 0, 'snr': 0, 'ber': 0, 'cosine': 0}

            # Training phase with gradient clipping
            for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["num_epochs"]} [Train]'):
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

                # Forward pass
                outputs = model(inputs)

                # Handle dimension mismatch
                if outputs.size(-1) > targets.size(-1):
                    targets = torch.cat([
                        targets,
                        torch.zeros(targets.size(0), outputs.size(-1) - targets.size(-1),
                        device=device)
                    ], dim=-1)

                loss = criterion(outputs, targets)

                # Calculate metrics
                epoch_train_metrics['loss'] += loss.item()
                epoch_train_metrics['snr'] += calculate_snr(outputs, targets)
                epoch_train_metrics['ber'] += calculate_ber(outputs, targets)
                epoch_train_metrics['cosine'] += calculate_cosine_similarity(outputs, targets)

                # Backward pass with gradient clipping
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Average training metrics
            for k in epoch_train_metrics:
                epoch_train_metrics[k] /= len(train_loader)
            metrics['train']['loss'].append(epoch_train_metrics['loss'])
            metrics['train']['snr'].append(epoch_train_metrics['snr'])
            metrics['train']['ber'].append(epoch_train_metrics['ber'])
            metrics['train']['cosine'].append(epoch_train_metrics['cosine'])

            # Validation phase
            model.eval()
            epoch_val_metrics = {'loss': 0, 'snr': 0, 'ber': 0, 'cosine': 0}
            with torch.no_grad():
                for inputs, targets in tqdm(val_loader, desc=f'Epoch {epoch+1}/{config["num_epochs"]} [Val]'):
                    inputs = inputs.to(device, non_blocking=True)
                    targets = targets.to(device, non_blocking=True)

                    outputs = model(inputs)
                    if outputs.size(-1) > targets.size(-1):
                        targets = torch.cat([
                            targets,
                            torch.zeros(targets.size(0), outputs.size(-1) - targets.size(-1),
                            device=device)
                        ], dim=-1)

                    loss = criterion(outputs, targets)

                    epoch_val_metrics['loss'] += loss.item()
                    epoch_val_metrics['snr'] += calculate_snr(outputs, targets)
                    epoch_val_metrics['ber'] += calculate_ber(outputs, targets)
                    epoch_val_metrics['cosine'] += calculate_cosine_similarity(outputs, targets)

            # Average validation metrics and update LR
            for k in epoch_val_metrics:
                epoch_val_metrics[k] /= len(val_loader)
            metrics['val']['loss'].append(epoch_val_metrics['loss'])
            metrics['val']['snr'].append(epoch_val_metrics['snr'])
            metrics['val']['ber'].append(epoch_val_metrics['ber'])
            metrics['val']['cosine'].append(epoch_val_metrics['cosine'])
            scheduler.step(epoch_val_metrics['loss'])

            print(f"\nEpoch {epoch+1}:")
            print(f"Train - Loss: {epoch_train_metrics['loss']:.4f}, SNR: {epoch_train_metrics['snr']:.2f} dB, BER: {epoch_train_metrics['ber']:.4f}, Cosine: {epoch_train_metrics['cosine']:.4f}")
            print(f"Val   - Loss: {epoch_val_metrics['loss']:.4f}, SNR: {epoch_val_metrics['snr']:.2f} dB, BER: {epoch_val_metrics['ber']:.4f}, Cosine: {epoch_val_metrics['cosine']:.4f}")
            print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")

            # Save best model
            if epoch_val_metrics['loss'] < best_val_loss:
                best_val_loss = epoch_val_metrics['loss']
                checkpoint = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'metrics': metrics,
                    'config': config
                }
                torch.save(checkpoint, os.path.join(config['save_dir'], 'best_model.pth'))
                print("Saved new best model")

            # Save periodic checkpoint
            if (epoch + 1) % config['checkpoint_freq'] == 0:
                torch.save(checkpoint, os.path.join(config['save_dir'], f'checkpoint_epoch_{epoch+1}.pth'))

        # Test evaluation
        print("\nEvaluating on test set...")
        model.load_state_dict(torch.load(os.path.join(config['save_dir'], 'best_model.pth'))['model_state_dict'])
        model.eval()
        test_metrics = {'loss': 0, 'snr': 0, 'ber': 0, 'cosine': 0}
        with torch.no_grad():
            for inputs, targets in tqdm(test_loader, desc="Test Evaluation"):
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

                outputs = model(inputs)
                if outputs.size(-1) > targets.size(-1):
                    targets = torch.cat([
                        targets,
                        torch.zeros(targets.size(0), outputs.size(-1) - targets.size(-1),
                        device=device)
                    ], dim=-1)

                test_metrics['loss'] += criterion(outputs, targets).item()
                test_metrics['snr'] += calculate_snr(outputs, targets)
                test_metrics['ber'] += calculate_ber(outputs, targets)
                test_metrics['cosine'] += calculate_cosine_similarity(outputs, targets)

        # Average test metrics
        for k in test_metrics:
            test_metrics[k] /= len(test_loader)
        metrics['test'] = test_metrics

        print("\nFinal Test Metrics:")
        print(f"Loss: {test_metrics['loss']:.4f}, SNR: {test_metrics['snr']:.2f} dB, BER: {test_metrics['ber']:.4f}, Cosine: {test_metrics['cosine']:.4f}")

        # Save final results and plots
        results = {
            'metrics': metrics,
            'config': config,
            'best_epoch': checkpoint['epoch']
        }
        torch.save(results, os.path.join(config['save_dir'], 'final_results.pth'))

        # Plot metrics
        plt.figure(figsize=(15, 10))

        plt.subplot(2, 2, 1)
        plt.plot(metrics['train']['loss'], label='Train')
        plt.plot(metrics['val']['loss'], label='Validation')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(2, 2, 2)
        plt.plot(metrics['train']['snr'], label='Train')
        plt.plot(metrics['val']['snr'], label='Validation')
        plt.xlabel('Epoch')
        plt.ylabel('SNR (dB)')
        plt.legend()

        plt.subplot(2, 2, 3)
        plt.plot(metrics['train']['ber'], label='Train')
        plt.plot(metrics['val']['ber'], label='Validation')
        plt.xlabel('Epoch')
        plt.ylabel('BER')
        plt.legend()

        plt.subplot(2, 2, 4)
        plt.plot(metrics['train']['cosine'], label='Train')
        plt.plot(metrics['val']['cosine'], label='Validation')
        plt.xlabel('Epoch')
        plt.ylabel('Cosine Similarity')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(config['save_dir'], 'training_metrics.png'))
        plt.close()

        print("Training completed and results saved!")

    except Exception as e:
        print(f"Error during training: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Entry point: start the training run only when this code executes as the
# main program (script or notebook cell), not when it is imported.
if __name__ == "__main__":
    train_model()

Using device: cuda
Dataset sizes - Train: 64000, Val: 8000, Test: 8000


Epoch 1/50 [Train]: 100%|██████████| 8000/8000 [25:07<00:00,  5.31it/s]
Epoch 1/50 [Val]: 100%|██████████| 1000/1000 [00:05<00:00, 170.56it/s]


Epoch 1: Train Loss = 1.1558, Val Loss = 1.1550
Saved new best model


Epoch 2/50 [Train]: 100%|██████████| 8000/8000 [25:00<00:00,  5.33it/s]
Epoch 2/50 [Val]: 100%|██████████| 1000/1000 [00:06<00:00, 145.59it/s]


Epoch 2: Train Loss = 1.1556, Val Loss = 1.1546
Saved new best model


Epoch 3/50 [Train]: 100%|██████████| 8000/8000 [25:01<00:00,  5.33it/s]
Epoch 3/50 [Val]: 100%|██████████| 1000/1000 [00:06<00:00, 145.73it/s]


Epoch 3: Train Loss = 1.1555, Val Loss = 1.1546
Saved new best model


Epoch 4/50 [Train]: 100%|██████████| 8000/8000 [25:22<00:00,  5.25it/s]
Epoch 4/50 [Val]: 100%|██████████| 1000/1000 [00:05<00:00, 177.21it/s]


Epoch 4: Train Loss = 1.1555, Val Loss = 1.1545
Saved new best model


Epoch 5/50 [Train]: 100%|██████████| 8000/8000 [25:00<00:00,  5.33it/s]
Epoch 5/50 [Val]: 100%|██████████| 1000/1000 [00:06<00:00, 163.76it/s]


Epoch 5: Train Loss = 1.1556, Val Loss = 1.1544
Saved new best model


Epoch 6/50 [Train]: 100%|██████████| 8000/8000 [24:35<00:00,  5.42it/s]
Epoch 6/50 [Val]: 100%|██████████| 1000/1000 [00:06<00:00, 163.74it/s]


Epoch 6: Train Loss = 1.1556, Val Loss = 1.1545


Epoch 7/50 [Train]: 100%|██████████| 8000/8000 [25:19<00:00,  5.26it/s]
Epoch 7/50 [Val]: 100%|██████████| 1000/1000 [00:06<00:00, 152.39it/s]


Epoch 7: Train Loss = 1.1555, Val Loss = 1.1542
Saved new best model


Epoch 8/50 [Train]: 100%|██████████| 8000/8000 [25:10<00:00,  5.30it/s]
Epoch 8/50 [Val]: 100%|██████████| 1000/1000 [00:05<00:00, 172.61it/s]


Epoch 8: Train Loss = 1.1557, Val Loss = 1.1541
Saved new best model


Epoch 9/50 [Train]: 100%|██████████| 8000/8000 [25:26<00:00,  5.24it/s]
Epoch 9/50 [Val]: 100%|██████████| 1000/1000 [00:06<00:00, 157.73it/s]


Epoch 9: Train Loss = 1.1558, Val Loss = 1.1545


Epoch 10/50 [Train]:  81%|████████  | 6460/8000 [20:30<04:37,  5.55it/s]

Epoch 22: Train Loss = 1.1565, Val Loss = 1.1572
Saved new best model
Epoch 23/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [41:07<00:00,  3.24it/s]
Epoch 23/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [10:33<00:00,  1.58it/s]
Epoch 23: Train Loss = 1.1566, Val Loss = 1.1575
Epoch 24/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [45:22<00:00,  2.94it/s]
Epoch 24/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [20:47<00:00,  1.25s/it]
Epoch 24: Train Loss = 1.1564, Val Loss = 1.1576
Epoch 25/50 [Train]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [1:20:06<00:00,  1.66it/s]
Epoch 25/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [17:34<00:00,  1.05s/it]
Epoch 25: Train Loss = 1.1564, Val Loss = 1.1576
Epoch 26/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [52:01<00:00,  2.56it/s]
Epoch 26/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:42<00:00,  1.13it/s]
Epoch 26: Train Loss = 1.1564, Val Loss = 1.1575
Epoch 27/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [54:16<00:00,  2.46it/s]
Epoch 27/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [17:02<00:00,  1.02s/it]
Epoch 27: Train Loss = 1.1564, Val Loss = 1.1574
Epoch 28/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [49:34<00:00,  2.69it/s]
Epoch 28/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:17<00:00,  1.25it/s]
Epoch 28: Train Loss = 1.1565, Val Loss = 1.1572
Saved new best model
Epoch 29/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [51:09<00:00,  2.61it/s]
Epoch 29/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [19:54<00:00,  1.19s/it]
Epoch 29: Train Loss = 1.1566, Val Loss = 1.1573
Epoch 30/50 [Train]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [1:02:13<00:00,  2.14it/s]
Epoch 30/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:29<00:00,  1.15it/s]
Epoch 30: Train Loss = 1.1566, Val Loss = 1.1573
Epoch 31/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [45:00<00:00,  2.96it/s]
Epoch 31/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [11:46<00:00,  1.42it/s]
Epoch 31: Train Loss = 1.1566, Val Loss = 1.1572
Saved new best model
Epoch 32/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [43:49<00:00,  3.04it/s]
Epoch 32/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [11:41<00:00,  1.43it/s]
Epoch 32: Train Loss = 1.1566, Val Loss = 1.1571
Saved new best model
Epoch 33/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [44:19<00:00,  3.01it/s]
Epoch 33/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:09<00:00,  1.37it/s]
Epoch 33: Train Loss = 1.1566, Val Loss = 1.1569
Saved new best model
Epoch 34/50 [Train]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [1:45:32<00:00,  1.26it/s]
Epoch 34/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:59<00:00,  1.19it/s]
Epoch 34: Train Loss = 1.1566, Val Loss = 1.1567
Saved new best model
Epoch 35/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [50:14<00:00,  2.65it/s]
Epoch 35/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [15:35<00:00,  1.07it/s]
Epoch 35: Train Loss = 1.1566, Val Loss = 1.1568
Epoch 36/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [50:59<00:00,  2.62it/s]
Epoch 36/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [18:24<00:00,  1.10s/it]
Epoch 36: Train Loss = 1.1565, Val Loss = 1.1568
Epoch 37/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [49:05<00:00,  2.72it/s]
Epoch 37/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:01<00:00,  1.39it/s]
Epoch 37: Train Loss = 1.1565, Val Loss = 1.1570
Epoch 38/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [49:43<00:00,  2.68it/s]
Epoch 38/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:39<00:00,  1.14it/s]
Epoch 38: Train Loss = 1.1566, Val Loss = 1.1569
Epoch 39/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [45:42<00:00,  2.92it/s]
Epoch 39/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:58<00:00,  1.28it/s]
Epoch 39: Train Loss = 1.1566, Val Loss = 1.1569
Epoch 40/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [50:13<00:00,  2.65it/s]
Epoch 40/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:00<00:00,  1.28it/s]
Epoch 40: Train Loss = 1.1566, Val Loss = 1.1570
Epoch 41/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [46:11<00:00,  2.89it/s]
Epoch 41/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]
Epoch 41: Train Loss = 1.1566, Val Loss = 1.1569
Epoch 42/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [48:34<00:00,  2.75it/s]
Epoch 42/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:39<00:00,  1.32it/s]
Epoch 42: Train Loss = 1.1567, Val Loss = 1.1570
Epoch 43/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [46:47<00:00,  2.85it/s]
Epoch 43/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [18:43<00:00,  1.12s/it]
Epoch 43: Train Loss = 1.1568, Val Loss = 1.1572
Epoch 44/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [46:20<00:00,  2.88it/s]
Epoch 44/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:43<00:00,  1.31it/s]
Epoch 44: Train Loss = 1.1566, Val Loss = 1.1573
Epoch 45/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [50:12<00:00,  2.66it/s]
Epoch 45/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [16:50<00:00,  1.01s/it]
Epoch 45: Train Loss = 1.1567, Val Loss = 1.1578
Epoch 46/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [48:08<00:00,  2.77it/s]
Epoch 46/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:17<00:00,  1.25it/s]
Epoch 46: Train Loss = 1.1567, Val Loss = 1.1579
Epoch 47/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [52:01<00:00,  2.56it/s]
Epoch 47/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [16:20<00:00,  1.02it/s]
Epoch 47: Train Loss = 1.1565, Val Loss = 1.1578
Epoch 48/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [49:04<00:00,  2.72it/s]
Epoch 48/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [18:22<00:00,  1.10s/it]
Epoch 48: Train Loss = 1.1564, Val Loss = 1.1579
Epoch 49/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [47:53<00:00,  2.78it/s]
Epoch 49/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:20<00:00,  1.25it/s]
Epoch 49: Train Loss = 1.1564, Val Loss = 1.1582
Epoch 50/50 [Train]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [49:10<00:00,  2.71it/s]
Epoch 50/50 [Val]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [17:17<00:00,  1.04s/it]
Epoch 50: Train Loss = 1.1563, Val Loss = 1.1571, snr = 18.7, ber = 4.6, cosine = 0.91

Evaluating on test set...
Test Evaluation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [15:07<00:00,  1.10it/s]

Final Test Loss: 1.1539, snr = 18.2, ber = 5.0, cosine = 0.8634