In [94]:
import torch
import torch.nn as nn

# 数据生成

In [95]:
import torch
import h5py
import numpy as np

class DataGenerator:
  def __init__(self, Nt, Nr, c, c_sub, number, max_value, data_root=None):
    self.Nt = Nt
    self.Nr = Nr
    self.c = c
    self.c_sub = c_sub
    self.number = number
    self.max_value = max_value
    self.data_root = data_root
    
  def generate(self):
    raw_data = self.max_value*np.random.uniform(-1, 1, (self.number, self.c_sub*self.c, self.Nt*self.Nr))
    return raw_data

  def fromDataRoot(self):
    mat_transpose = h5py.File(self.data_root)['Final'][:self.c_sub*self.number]
    mat = mat_transpose.T.reshape(self.Nt, self.Nr, self.c, self.c_sub, self.number)
    mat = mat.transpose(-1, 0, 1, 2, 3)
    
    mat_flat = mat.flatten()
    mat_real = np.array([mat_flat[i][0] for i in range(mat_flat.shape[0])])
    mat_img = np.array([mat_flat[i][1] for i in range(mat_flat.shape[0])])
    
    mat_real.resize((self.number, self.Nt, self.Nr, self.c, self.c_sub))
    mat_img.resize((self.number, self.Nt, self.Nr, self.c, self.c_sub))
    
    return mat_real

In [96]:
# raw data example
g = DataGenerator(16, 16, 5, 256, 100, 1)
raw_data = g.generate()

# 数据集和数据加载器

In [97]:
from torch.utils.data import Dataset

class MimoDataset(Dataset):
  def __init__(self, raw_data_array):
    self.raw_data_array = torch.tensor(raw_data_array, dtype=torch.float32)

  def __len__(self):
    return self.raw_data_array.shape[0]
  
  def __getitem__(self, ind):
    src, target = self.raw_data_array[ind], self.raw_data_array.clone()[ind]
    return src, target

In [98]:
mimo_data = MimoDataset(raw_data_array=raw_data)

In [99]:
from torch.utils.data import random_split
from torch.utils.data import DataLoader

class MimoDataLoader:
  def __init__(self, dataset, batch_size=16, train_rho=0.8, val_rho=0.1):
    self.dataset = dataset
    self.dataset_len = len(self.dataset)
    self.batch_size = batch_size
    self.train_num = int(train_rho*self.dataset_len)
    self.val_num = int(val_rho*self.dataset_len)
    self.test_num = self.dataset_len - self.train_num - self.val_num
  
  def getDataLoader(self):
    train_dataset, val_dataset, test_dataset, = random_split(self.dataset, [self.train_num, self.val_num, self.test_num])
    
    train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=True)
    
    return train_dataloader, test_dataloader, val_dataloader

In [100]:
m = MimoDataLoader(mimo_data)
train_dataloader, test_dataloader, val_dataloader = m.getDataLoader()

# 模型定义

In [101]:
class FrequenceNet(nn.Module):
  def __init__(self, extract_model, restore_model, input_dim, compress_ratio):
    super().__init__()
    self.extract_model = extract_model
    self.restore_model = restore_model
    assert not input_dim%compress_ratio, ValueError('compress_ratio must be divisible by input_dim.')
    self.compress_dim = int(input_dim//compress_ratio)
    self.full_connect_encoder = nn.Linear(input_dim, self.compress_dim)
    self.full_connect_decoder = nn.Linear(self.compress_dim, input_dim)

  def forward(self, data):
    extract_feature = self.extract_model(data)
    compress_feature = self.full_connect_encoder(extract_feature)
    restore_feature_raw = self.full_connect_decoder(compress_feature)
    restore_feature = self.restore_model(restore_feature_raw)
    return restore_feature

In [102]:
# Get frequence_model
input_dim = 256
extract_model = nn.TransformerEncoderLayer(d_model = input_dim, nhead=8, dropout=0.)
extract_model, restore_model = nn.Linear(input_dim, input_dim), nn.Linear(input_dim, input_dim)
frequence_model = FrequenceNet(extract_model, restore_model, input_dim, compress_ratio=4)

In [103]:
# Check params should be updated during training
for name, param in frequence_model.named_parameters():
    if param.requires_grad:
        print(name)

extract_model.weight
extract_model.bias
restore_model.weight
restore_model.bias
full_connect_encoder.weight
full_connect_encoder.bias
full_connect_decoder.weight
full_connect_decoder.bias


# 训练器

In [104]:
import os
import torch
import json
import numpy as np
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Trainer:
    def __init__(
        self,
        model,
        epoch,
        train_dataloader,
        val_dataloader,
        training_step,
        valid_step,
        criterion,
        optimizer,
        lr_scheduler,
        check_frequence,
        model_savepath,
        checkpoint_path,
        loss_path,
        
    ):
        self.model = model
        self.model.to(device)
        self.epoch = epoch
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.training_step = training_step
        self.valid_step = valid_step
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.check_frequence = check_frequence
        self.model_savepath = model_savepath
        self.checkpoint_path = checkpoint_path
        self.nmse = lambda x: 10*math.log10(x)
        self.loss = {
            'train':[],
            'val':[]
        }
        self.loss_path = loss_path
    
    def train(self):
        for e in range(self.epoch):
            self.train_epoch()
            self.valid_epoch()
            print('Epoch: {:d}, Training loss: {:3f}, Valid loss: {:3f}\n'.format(
                e + 1,
                self.nmse(self.loss['train'][-1]), 
                self.nmse(self.loss['val'][-1])
                )
            )
            
            self.lr_scheduler.step()
            # if (not (e + 1)%self.check_frequence) or ((e + 1) == self.epoch):
            #     self.getCheckpoint(e + 1)
        # self.saveLoss()
        
    
    def train_epoch(self):
        self.model.train()
        running_loss = []
        
        for ind, data in enumerate(self.train_dataloader):
            src, tag = data[0].to(device), data[1].to(device)
            
            self.optimizer.zero_grad()
            out = self.model(src)
            loss = self.criterion(out, tag)
            loss.backward()
            running_loss.append(loss.item())
            self.optimizer.step()

            if ind == self.training_step:
                break
        
        epoch_loss = np.mean(running_loss)
        
        self.loss['train'].append(epoch_loss)

    def valid_epoch(self):
        self.model.eval()
        running_loss = []
        
        with torch.no_grad():
            for ind, data in enumerate(self.val_dataloader):
                src, tag = data[0].to(device), data[1].to(device)
                
                out = self.model(src)
                loss = self.criterion(out, tag)
                running_loss.append(loss.item())

                if ind == self.valid_step:
                    break
        
        epoch_loss = np.mean(running_loss)     
        self.loss['val'].append(epoch_loss)
    
    def getCheckpoint(self,):
        save_path = os.path.join(self.model_savepath, 'model.pt')
        torch.save(self.model, save_path)
    
    def saveLoss(self):
        loss_path = os.path.join(self.loss_path, 'loss.json')
        with open(loss_path, 'w') as f:
            json.dump(self.loss, f) 

In [105]:
from torch.nn.functional import mse_loss
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim

criterion = mse_loss
optimizer = optim.Adam(params=frequence_model.parameters(),lr=1e-3)
max_epoch = 100
scheduler = LambdaLR(
    optimizer=optimizer,
    lr_lambda= lambda x: (max_epoch - x)/max_epoch,
    verbose = False
)
trainer = Trainer(
    model=frequence_model,
    epoch=100,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    training_step=1,
    valid_step=1,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    check_frequence=5,
    model_savepath='.',
    checkpoint_path='.',
    loss_path='.')

In [106]:
trainer.train()

Epoch: 1, Training loss: -4.706240, Valid loss: -4.800242

Epoch: 2, Training loss: -4.830762, Valid loss: -4.905098

Epoch: 3, Training loss: -4.931091, Valid loss: -4.997410

Epoch: 4, Training loss: -5.022994, Valid loss: -5.085013

Epoch: 5, Training loss: -5.109554, Valid loss: -5.171806

Epoch: 6, Training loss: -5.197677, Valid loss: -5.258460

Epoch: 7, Training loss: -5.283744, Valid loss: -5.343642

Epoch: 8, Training loss: -5.367965, Valid loss: -5.425331

Epoch: 9, Training loss: -5.448783, Valid loss: -5.501887

Epoch: 10, Training loss: -5.526093, Valid loss: -5.572560

Epoch: 11, Training loss: -5.596690, Valid loss: -5.637323

Epoch: 12, Training loss: -5.659301, Valid loss: -5.696299

Epoch: 13, Training loss: -5.717804, Valid loss: -5.749465

Epoch: 14, Training loss: -5.768872, Valid loss: -5.796848

Epoch: 15, Training loss: -5.816732, Valid loss: -5.838599

Epoch: 16, Training loss: -5.857827, Valid loss: -5.874993

Epoch: 17, Training loss: -5.895559, Valid loss: 

## 测试

In [107]:
frequence_model.eval()

FrequenceNet(
  (extract_model): Linear(in_features=256, out_features=256, bias=True)
  (restore_model): Linear(in_features=256, out_features=256, bias=True)
  (full_connect_encoder): Linear(in_features=256, out_features=64, bias=True)
  (full_connect_decoder): Linear(in_features=64, out_features=256, bias=True)
)

In [136]:
frequence_model

FrequenceNet(
  (extract_model): Linear(in_features=256, out_features=256, bias=True)
  (restore_model): Linear(in_features=256, out_features=256, bias=True)
  (full_connect_encoder): Linear(in_features=256, out_features=64, bias=True)
  (full_connect_decoder): Linear(in_features=64, out_features=256, bias=True)
)

In [146]:
test_src, test_tag= next(iter(test_dataloader))
test_src = test_src.to(device)
test_tag = test_tag.to(device)
test_src.is_cuda

True

In [157]:
model_out = frequence_model(test_src)

In [None]:
def simple_norm(tensor):
    max_value, min_value = torch.max(tensor), torch.min(tensor)
    norm_F = lambda x: (x - min_value)/abs(max_value)
    tensor.apply_(norm_F)

In [177]:
norm_out = torch.nn.functional.normalize(model_out, p=2, dim=0)
norm_src = torch.nn.functional.normalize(test_src, p=2, dim=0)

In [163]:
print(torch.max(model_out))

tensor(1.4023, device='cuda:0', grad_fn=<MaxBackward1>)


In [156]:
test_src.shape

torch.Size([10, 1280, 256])