In [1]:
!pip install scikit-learn torch pandas numpy

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import pandas as pd
df = pd.read_csv('./awgn_dataset.csv')

In [3]:
num_columns = len(df.columns)
num_columns

1153

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torch
import pandas as pd
from typing import List, Tuple

class ErrorCorrectionDataset:
    """
    Dataset for AWGN error correction.
    Each sample contains:
      - X: corrupted bytes as a sequence [seq_len, 1], normalized to [-1, 1]
      - y: original bytes as class indices [seq_len]
    """
    def __init__(self, dataframe, psnr_pools: List[float], train_size: float = 0.8, seq_len=576, normalize=True):
        self.df = df
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.psnr_pools = psnr_pools
        self.train_size = train_size
        self.seq_len = seq_len
        self.normalize = normalize
        self.corrupted_cols = [f"Corrupted_{i}" for i in range(seq_len)]
        self.original_cols = [f"Original_{i}" for i in range(seq_len)]
        self.X = dataframe[[f"Corrupted_{i}" for i in range(seq_len)]].values.astype(np.uint8)
        self.Y = dataframe[[f"Original_{i}" for i in range(seq_len)]].values.astype(np.uint8)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        corrupted = self.X[idx]  # shape (576,)
        original = self.Y[idx]   # shape (576,)

        if self.normalize:
            corrupted = (corrupted.astype(np.float32) / 127.5) - 1.0
        else:
            corrupted = corrupted.astype(np.float32)

        X = torch.tensor(corrupted, dtype=torch.float32).unsqueeze(-1)  # [576, 1]
        y = torch.tensor(original, dtype=torch.long)                    # [576]
        return X, y

    def _split_tensor(self, sub_df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Helper to split a dataframe into train/test torch tensors."""
        n_train = int(self.train_size * len(sub_df))
        sub_df = sub_df.sample(frac=1, random_state=None).reset_index(drop=True)

        X = sub_df[self.corrupted_cols].values.astype("float32") / 255.0
        y = sub_df[self.original_cols].values.astype("float32") / 255.0

        X_train, X_test = X[:n_train], X[n_train:]
        y_train, y_test = y[:n_train], y[n_train:]

        return (
            torch.tensor(X_train, dtype=torch.float32).to(self.device),
            torch.tensor(y_train, dtype=torch.float32).to(self.device),
            torch.tensor(X_test, dtype=torch.float32).to(self.device),
            torch.tensor(y_test, dtype=torch.float32).to(self.device),
        )

    def prepare_datasets(self) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Returns a list of datasets (X_train, y_train, X_test, y_test) for each PSNR pool.
        """
        datasets = []
        for psnr in self.psnr_pools:
            sub_df = self.df[self.df["PSNR"] == psnr]
            datasets.append(self._split_tensor(sub_df))
        return datasets

    def generalized_dataset(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns a single dataset (X_train, y_train, X_test, y_test) combining all PSNR pools.
        """
        return self._split_tensor(self.df)

In [7]:
psnr_pools = [0.025, 0.1, 1, 2, 5, 50]
psnr_pools.reverse()
Dataset = ErrorCorrectionDataset(df, psnr_pools, train_size=0.8)

In [8]:
error_datasets = Dataset.prepare_datasets()
generalized_dataset = Dataset.generalized_dataset()

AttributeError: 'ErrorCorrectionDataset' object has no attribute 'corrupted_cols'

In [None]:
class AWGNErrorCorrector(nn.Module):
    """
    LSTM-based error correction model.
    Takes in corrupted byte sequence and predicts original bytes.
    """
    def __init__(self, input_size=1, hidden_size=256, num_layers=2, dropout=0.1):
        super().__init__()
        self.input_fc = nn.Linear(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers,
                            batch_first=True, dropout=dropout)
        self.out_fc = nn.Linear(hidden_size, 256)  # 256-way classification for byte values

    def forward(self, x):
        # x: [batch, seq_len, 1]
        x_proj = torch.relu(self.input_fc(x))   # [batch, seq_len, hidden]
        out, _ = self.lstm(x_proj)              # [batch, seq_len, hidden]
        logits = self.out_fc(out)               # [batch, seq_len, 256]
        return logits

In [None]:
n_features = num_columns // 2
model = AWGNErrorCorrector(n_features).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
X_train_00 = error_datasets[0][0].unsqueeze(0)
y_train_00 = error_datasets[0][1].unsqueeze(0)

for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(X_train_00)  
    loss = criterion(
        outputs.reshape(-1, 256),     
        y_train_00.reshape(-1)        
    )
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

ValueError: Expected input batch_size (3968) to match target batch_size (2285568).

In [None]:
y_train_00.shape, X_train_00.shape

(torch.Size([1, 3968, 576]), torch.Size([1, 3968, 576]))

In [None]:
X_train_01 = error_datasets[1][0].unsqueeze(0)
y_train_01 = error_datasets[1][1].unsqueeze(0)

for epoch in range(80):
    optimizer.zero_grad()
    outputs = model(X_train_01)  
    loss = criterion(
        outputs.view(-1, 256),     
        y_train_01.view(-1)        
    )
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

Epoch 1, Loss=0.659095
Epoch 2, Loss=0.659102
Epoch 3, Loss=0.658696
Epoch 4, Loss=0.658377
Epoch 5, Loss=0.658324
Epoch 6, Loss=0.657725
Epoch 7, Loss=0.657372
Epoch 8, Loss=0.657420
Epoch 9, Loss=0.656842
Epoch 10, Loss=0.656712
Epoch 11, Loss=0.656581
Epoch 12, Loss=0.656127
Epoch 13, Loss=0.656123
Epoch 14, Loss=0.655745
Epoch 15, Loss=0.655605
Epoch 16, Loss=0.655392
Epoch 17, Loss=0.655119
Epoch 18, Loss=0.655005
Epoch 19, Loss=0.654697
Epoch 20, Loss=0.654590
Epoch 21, Loss=0.654306
Epoch 22, Loss=0.654172
Epoch 23, Loss=0.653945
Epoch 24, Loss=0.653755
Epoch 25, Loss=0.653577
Epoch 26, Loss=0.653323
Epoch 27, Loss=0.653186
Epoch 28, Loss=0.652943
Epoch 29, Loss=0.652797
Epoch 30, Loss=0.652611
Epoch 31, Loss=0.652422
Epoch 32, Loss=0.652240
Epoch 33, Loss=0.652032
Epoch 34, Loss=0.651858
Epoch 35, Loss=0.651666
Epoch 36, Loss=0.651510
Epoch 37, Loss=0.651346
Epoch 38, Loss=0.651180
Epoch 39, Loss=0.651041
Epoch 40, Loss=0.650888
Epoch 41, Loss=0.650742
Epoch 42, Loss=0.650579
E

In [None]:
X_train_02 = error_datasets[2][0].unsqueeze(0)
y_train_02 = error_datasets[2][1].unsqueeze(0)

for epoch in range(60):
    optimizer.zero_grad()
    outputs = model(X_train_02)  
    loss = criterion(
        outputs.view(-1, 256),     
        y_train_02.view(-1)        
    )
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

Epoch 1, Loss=0.677091
Epoch 2, Loss=0.677451
Epoch 3, Loss=0.677787
Epoch 4, Loss=0.677483
Epoch 5, Loss=0.676791
Epoch 6, Loss=0.676596
Epoch 7, Loss=0.676641
Epoch 8, Loss=0.676182
Epoch 9, Loss=0.675914
Epoch 10, Loss=0.675907
Epoch 11, Loss=0.675466
Epoch 12, Loss=0.675289
Epoch 13, Loss=0.675119
Epoch 14, Loss=0.674769
Epoch 15, Loss=0.674645
Epoch 16, Loss=0.674327
Epoch 17, Loss=0.674165
Epoch 18, Loss=0.673893
Epoch 19, Loss=0.673706
Epoch 20, Loss=0.673477


In [None]:
X_train_03 = error_datasets[3][0].unsqueeze(0)
y_train_03 = error_datasets[3][1].unsqueeze(0)

for epoch in range(40):
    optimizer.zero_grad()
    outputs = model(X_train_03)  
    loss = criterion(
        outputs.view(-1, 256),     
        y_train_03.view(-1)        
    )
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

Epoch 1, Loss=0.690970
Epoch 2, Loss=0.690746
Epoch 3, Loss=0.690396
Epoch 4, Loss=0.690054
Epoch 5, Loss=0.689600
Epoch 6, Loss=0.689132
Epoch 7, Loss=0.688616
Epoch 8, Loss=0.688094
Epoch 9, Loss=0.687588
Epoch 10, Loss=0.687144


In [None]:
X_train_04 = error_datasets[4][0].unsqueeze(0)
y_train_04 = error_datasets[4][1].unsqueeze(0)

for epoch in range(20):
    optimizer.zero_grad()
    outputs = model(X_train_04)  
    loss = criterion(
        outputs.view(-1, 256),     
        y_train_04.view(-1)        
    )
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

In [None]:
X_train_05 = error_datasets[5][0].unsqueeze(0)
y_train_05 = error_datasets[5][1].unsqueeze(0)

for epoch in range(10):
    optimizer.zero_grad()
    outputs = model(X_train_05)  
    loss = criterion(
        outputs.view(-1, 256),     
        y_train_05.view(-1)        
    )
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

In [None]:
X_train = generalized_dataset[0].unsqueeze(0)
y_train = generalized_dataset[1].unsqueeze(0)

for epoch in range(20):
    optimizer.zero_grad()
    outputs = model(X_train)  
    loss = criterion(
        outputs.view(-1, 256),     
        y_train.view(-1)        
    )
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss={loss.item():.6f}")

Epoch 1, Loss=0.664328
Epoch 2, Loss=0.664662
Epoch 3, Loss=0.664613
Epoch 4, Loss=0.664188
Epoch 5, Loss=0.663560
Epoch 6, Loss=0.662773
Epoch 7, Loss=0.661944
Epoch 8, Loss=0.661126
Epoch 9, Loss=0.660486
Epoch 10, Loss=0.660104
Epoch 11, Loss=0.659872
Epoch 12, Loss=0.659651
Epoch 13, Loss=0.659421
Epoch 14, Loss=0.659322
Epoch 15, Loss=0.659249
Epoch 16, Loss=0.659121
Epoch 17, Loss=0.659037
Epoch 18, Loss=0.659040
Epoch 19, Loss=0.659243
Epoch 20, Loss=0.660288


In [None]:
X_test = generalized_dataset[2].unsqueeze(0)
y_test = generalized_dataset[3].unsqueeze(0)

model.eval()
with torch.no_grad():
    # ---- Test metrics ----
    test_outputs = model(X_test)  
    test_loss = criterion(
        test_outputs.view(-1, 256),     
        y_test.view(-1)        
    )
    
    test_preds = test_outputs.argmax(dim=-1)
    test_targets = y_test
    test_acc = (test_preds == test_targets).float().mean().item()

    print(f"Test  Loss = {test_loss.item():.6f}, Test Byte Accuracy = {test_acc*100:.6f}%")

Test  Loss = 0.662494, Test Byte Accuracy = 0.500357%


In [None]:
model.eval()
with torch.no_grad():
    # ---- Train metrics ----
    train_outputs = model(X_train)
    train_loss = criterion(
        train_outputs.view(-1, 256),     
        y_train.view(-1)        
    )

    train_preds = train_outputs.argmax(dim=-1)
    train_targets = y_train
    train_acc = (train_preds == train_targets).float().mean().item()

    print(f"Train Loss = {train_loss.item():.6f}, Train Byte Accuracy = {train_acc*100:.6f}%")

Train Loss = 0.662961, Train Byte Accuracy = 0.513877%


In [None]:
import os

save_path = "./awgn_error_corrector.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

Model saved to ./error_correction/error_corrector_updated_dataset.pth
