In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


# 数据生成

In [2]:
import torch
import h5py
import numpy as np

class DataGenerator:
  def __init__(self, Nt, Nr, c, c_sub, number, max_value, data_root=None):
    self.Nt = Nt
    self.Nr = Nr
    self.c = c
    self.c_sub = c_sub
    self.number = number
    self.max_value = max_value
    self.data_root = data_root
    
  def generate(self):
    raw_data = self.max_value*np.random.uniform(-1, 1, (self.number, self.c_sub*self.c, self.Nt*self.Nr))
    return raw_data

  def fromDataRoot(self):
    mat_transpose = h5py.File(self.data_root)['Final'][:self.c_sub*self.number]
    mat = mat_transpose.T.reshape(self.Nt, self.Nr, self.c, self.c_sub, self.number)
    mat = mat.transpose(-1, 0, 1, 2, 3)
    
    mat_flat = mat.flatten()
    mat_real = np.array([mat_flat[i][0] for i in range(mat_flat.shape[0])])
    mat_img = np.array([mat_flat[i][1] for i in range(mat_flat.shape[0])])
    
    mat_real.resize((self.number, self.Nt, self.Nr, self.c, self.c_sub))
    mat_img.resize((self.number, self.Nt, self.Nr, self.c, self.c_sub))
    
    return mat_real

In [3]:
# raw data example
g = DataGenerator(16, 16, 5, 512, 100, 0.001)
raw_data = g.generate()

# 数据集和数据加载器

In [4]:
from torch.utils.data import Dataset

class MimoDataset(Dataset):
  def __init__(self, raw_data_array):
    self.raw_data_array = torch.tensor(raw_data_array, dtype=torch.float32)

  def __len__(self):
    return self.raw_data_array.shape[0]
  
  def __getitem__(self, ind):
    src, target = self.raw_data_array[ind], self.raw_data_array.clone()[ind]
    return src, target

In [5]:
mimo_data = MimoDataset(raw_data_array=raw_data)

In [6]:
from torch.utils.data import random_split
from torch.utils.data import DataLoader

class MimoDataLoader:
  def __init__(self, dataset, batch_size=16, train_rho=0.8, val_rho=0.1):
    self.dataset = dataset
    self.dataset_len = len(self.dataset)
    self.batch_size = batch_size
    self.train_num = int(train_rho*self.dataset_len)
    self.val_num = int(val_rho*self.dataset_len)
    self.test_num = self.dataset_len - self.train_num - self.val_num
  
  def getDataLoader(self):
    train_dataset, val_dataset, test_dataset, = random_split(self.dataset, [self.train_num, self.val_num, self.test_num])
    
    train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=True)
    
    return train_dataloader, test_dataloader, val_dataloader

In [7]:
m = MimoDataLoader(mimo_data)
train_dataloader, test_dataloader, val_dataloader = m.getDataLoader()

# 模型定义

In [8]:

# 范式 1
class FrequenceNet(nn.Module):
  def __init__(self, extract_model, restore_model, input_dim, compress_ratio):
    super().__init__()
    self.extract_model = extract_model
    self.restore_model = restore_model
    assert not input_dim%compress_ratio, ValueError('compress_ratio must be divisible by input_dim.')
    self.compress_dim = int(input_dim//compress_ratio)
    self.full_connect_encoder = nn.Linear(input_dim, self.compress_dim)
    self.full_connect_decoder = nn.Linear(self.compress_dim, input_dim)
    
    for m in self.modules():
      if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)     

  def forward(self, data):
    extract_feature = self.extract_model(data)
    compress_feature = self.full_connect_encoder(extract_feature)
    restore_feature_raw = self.full_connect_decoder(compress_feature)
    restore_feature = self.restore_model(restore_feature_raw)
    return restore_feature

In [9]:
# 范式 2
class ExtractionStrategies:
    def __init__(self, input_dim, compress_ratio) -> None:
        self.input_dim = input_dim
        self.compress_ratio = compress_ratio
        self.compress_dim = int(input_dim//compress_ratio)
        
    def equidistantExtraction(self):
        return lambda x: x[:, : :self.compress_ratio, :].clone()
    

class ChoseNet(nn.Module):
    def __init__(self, strategy_f, restore_model, input_num, restore_num):
        super().__init__()
        self.strategy_f = strategy_f
        self.restore_model = restore_model
        self.full_connect_decoder = nn.Linear(input_num, restore_num)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)     
        
    def forward(self, data):
        data = self.strategy_f(data)
        restore_feature_raw = self.full_connect_decoder(data.transpose(-1, -2)).transpose(-1, -2)
        restore_feature = self.restore_model(restore_feature_raw)
        return restore_feature

In [10]:
# Get restore model
ES = ExtractionStrategies(512, 4)
restore_model = nn.Linear(512, 512)
chose_model = ChoseNet(ES.equidistantExtraction(), restore_model=restore_model, input_num= int(512*5/4), restore_num=2560)

In [11]:
# Get frequence_model
input_dim = 512
extract_model = nn.Linear(input_dim, input_dim)
frequence_model = FrequenceNet(extract_model, extract_model, input_dim, compress_ratio=4)

In [12]:
# Check params should be updated during training
for name, param in frequence_model.named_parameters():
    if param.requires_grad:
        print(name)
print()
print('*'*80)
print()
for name, param in chose_model.named_parameters():
    if param.requires_grad:
        print(name)

extract_model.weight
extract_model.bias
full_connect_encoder.weight
full_connect_encoder.bias
full_connect_decoder.weight
full_connect_decoder.bias

********************************************************************************

restore_model.weight
restore_model.bias
full_connect_decoder.weight
full_connect_decoder.bias


In [13]:
# See the parameters
for name, param in chose_model.named_parameters():
    print(name, param)

# for name, param in frequence_model.named_parameters():
#     print(name, param)

restore_model.weight Parameter containing:
tensor([[ 0.0111,  0.0230,  0.0243,  ...,  0.0272,  0.1175, -0.0493],
        [-0.0325, -0.0670, -0.0608,  ..., -0.0586,  0.0368,  0.0015],
        [ 0.0793,  0.0475, -0.0697,  ...,  0.0247, -0.0256,  0.0415],
        ...,
        [ 0.0003,  0.0112, -0.0159,  ...,  0.0244,  0.0035,  0.0395],
        [ 0.0259, -0.0424,  0.0493,  ...,  0.0829,  0.0299,  0.0089],
        [-0.0220, -0.0447, -0.0053,  ..., -0.0040,  0.0327,  0.0020]],
       requires_grad=True)
restore_model.bias Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.,

# 训练器

In [14]:
import torch
import numpy as np
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Trainer:
    def __init__(
        self,
        model,
        epoch,
        train_dataloader,
        val_dataloader,
        criterion,
        optimizer,
        lr_scheduler,
        res_root='./res',
        checkpoint_path='./res/checkpoint/',
        loss_path='./res/loss/',
        check_frequence=0,
        training_step=1,
        valid_step=1,
        
    ):
        self.model = model
        self.model.to(device)
        self.epoch = epoch
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.training_step = training_step
        self.valid_step = valid_step
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.check_frequence = check_frequence
        self.checkpoint_path = checkpoint_path
        self.log = lambda x: 10*math.log10(x)
        self.loss = {
            'train_mse':[],
            'val_mse':[],
        }
        self.res_root = res_root
        self.loss_path = loss_path
    
    def train(self):
        for e in range(self.epoch):
            self.train_epoch()
            self.valid_epoch()
            print('Epoch:{:n}, Average training loss: {:3f}, Average valid loss: {:3f} \n'.format(
                int(e + 1),
                self.loss['train_mse'][-1], 
                self.loss['val_mse'][-1]
                )
            )
            
            self.lr_scheduler.step()
        
    def train_epoch(self):
        self.model.train()
        running_loss = []
        
        for ind, data in enumerate(self.train_dataloader):
            src, tag = data[0].to(device), data[1].to(device)
            
            self.optimizer.zero_grad()
            out = self.model(src)
            loss = self.criterion(out, tag, reduction='sum')
            loss.backward()
            
            train_batch_mse = self.log(loss.item()/src.shape[0])
            running_loss.append(train_batch_mse)
            self.optimizer.step()

            if ind == self.training_step:
                break
        
        epoch_loss = np.mean(running_loss)
        
        self.loss['train_mse'].append(epoch_loss)

    def valid_epoch(self):
        self.model.eval()
        running_loss = []
        
        with torch.no_grad():
            for ind, data in enumerate(self.val_dataloader):
                src, tag = data[0].to(device), data[1].to(device)
                out = self.model(src)
                loss = self.criterion(out, tag, reduction='sum')
                running_loss.append(self.log(loss.item()/src.shape[0]))

                if ind == self.valid_step:
                    break
        
        epoch_loss = np.mean(running_loss)     
        self.loss['val_mse'].append(epoch_loss)

In [15]:
from torch.nn.functional import mse_loss
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim

criterion = mse_loss
max_epoch = 100

In [16]:
optimizer_1 = optim.Adam(params=chose_model.parameters(), lr=1e-2)
scheduler_1 = LambdaLR(
    optimizer=optimizer_1,
    lr_lambda= lambda x: (max_epoch - x)/max_epoch,
    verbose = False
)

optimizer_2 = optim.Adam(params=frequence_model.parameters(), lr=1e-2)
scheduler_2 = LambdaLR(
    optimizer=optimizer_2,
    lr_lambda= lambda x: (max_epoch - x)/max_epoch,
    verbose = False
)

In [17]:
trainer1 = Trainer(
    model=chose_model,
    epoch=100,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    training_step=1,
    valid_step=1,
    criterion=criterion,
    optimizer=optimizer_1,
    lr_scheduler=scheduler_1,
    check_frequence=5,
    checkpoint_path='.',
    loss_path='.')

trainer2 = Trainer(
    model=frequence_model,
    epoch=100,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    training_step=1,
    valid_step=1,
    criterion=criterion,
    optimizer=optimizer_2,
    lr_scheduler=scheduler_2,
    check_frequence=5,
    checkpoint_path='.',
    loss_path='.')

In [18]:
trainer1.train()

RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

In [None]:
trainer2.train()

Epoch:1, Average training loss: 7.992342, Average valid loss: 16.217878 

Epoch:2, Average training loss: 16.153667, Average valid loss: 20.800424 

Epoch:3, Average training loss: 22.739703, Average valid loss: 28.191642 

Epoch:4, Average training loss: 24.027857, Average valid loss: 22.136681 

Epoch:5, Average training loss: 20.003230, Average valid loss: 22.925304 

Epoch:6, Average training loss: 17.853768, Average valid loss: 22.381251 

Epoch:7, Average training loss: 18.757588, Average valid loss: 16.361514 

Epoch:8, Average training loss: 17.824463, Average valid loss: 14.942323 

Epoch:9, Average training loss: 12.872041, Average valid loss: 15.134530 

Epoch:10, Average training loss: 15.009959, Average valid loss: 10.707622 

Epoch:11, Average training loss: 9.959765, Average valid loss: 11.807482 

Epoch:12, Average training loss: 11.835893, Average valid loss: 10.111846 

Epoch:13, Average training loss: 9.547379, Average valid loss: 9.340590 

Epoch:14, Average trainin

## 测试

In [None]:
frequence_model.eval()

FrequenceNet(
  (extract_model): Linear(in_features=256, out_features=256, bias=True)
  (restore_model): Linear(in_features=256, out_features=256, bias=True)
  (full_connect_encoder): Linear(in_features=256, out_features=64, bias=True)
  (full_connect_decoder): Linear(in_features=64, out_features=256, bias=True)
)

In [None]:
frequence_model

FrequenceNet(
  (extract_model): Linear(in_features=256, out_features=256, bias=True)
  (restore_model): Linear(in_features=256, out_features=256, bias=True)
  (full_connect_encoder): Linear(in_features=256, out_features=64, bias=True)
  (full_connect_decoder): Linear(in_features=64, out_features=256, bias=True)
)

In [None]:
test_src, test_tag= next(iter(test_dataloader))
test_src = test_src.to(device)
test_tag = test_tag.to(device)
test_src.is_cuda

True

In [None]:
model_out = frequence_model(test_src)

In [None]:
def simple_norm(tensor):
    max_value, min_value = torch.max(tensor), torch.min(tensor)
    norm_F = lambda x: (x - min_value)/abs(max_value)
    tensor.apply_(norm_F)

In [None]:
norm_out = torch.nn.functional.normalize(model_out, p=2, dim=0)
norm_src = torch.nn.functional.normalize(test_src, p=2, dim=0)

In [None]:
print(torch.max(model_out))

tensor(0.0020, device='cuda:0', grad_fn=<MaxBackward1>)


In [None]:
test_src.shape

torch.Size([10, 1280, 256])