## Artifact Removal Transformer

### Model Overview

In [None]:
import sys

# Put the project's EEGART folder first on the import path (presumably where
# tf_model lives) before importing from it.
sys.path.insert(0, './FirstMultiModel/EEGART')
from tf_model import make_model
from torchinfo import summary

# Disabled: build a bare model with make_model and print its torchinfo summary
# for four dummy inputs.
# model = make_model(30, 30, N=2)
# print(summary(model, input_size=[(32, 30, 120),(32, 30, 120),(32, 120, 120),(32, 120, 120)], col_names=["input_size", "output_size", "num_params", "params_percent", "kernel_size"]))

### Hugging Face Config

##### Save Pre-trained Config

In [2]:
from tf_config import ARTConfig, ARTEncoder_CLSConfig

# ART config: 30 source channels, 30 target channels, N=2; saved under "test_config-art".
art_config = ARTConfig(src_channel_size=30, tgt_channel_size=30, N=2)
art_config.save_pretrained("test_config-art")

# Encoder-classifier config: same source width and N, but tgt_channel_size=2;
# saved under "artcls-config".
artcls_config = ARTEncoder_CLSConfig(src_channel_size=30, tgt_channel_size=2, N=2)
artcls_config.save_pretrained("artcls-config")

In [None]:
# Read the ART config back from "test_config-art" and display it.
test_config = ARTConfig.from_pretrained("test_config-art")
print(test_config)

# Read the classifier config back from "artcls-config" (rebinding artcls_config)
# and display it.
artcls_config = ARTEncoder_CLSConfig.from_pretrained("artcls-config")
print(artcls_config)


##### Save Pre-trained Model Weights

In [4]:
from tf_model import ARTModel, ARTCLSModel, ART_CLS_PreTrain
import torch

# Build the three model wrappers from the configs defined in the earlier cells
# (test_config for ART, artcls_config for both classifier variants).
art_model = ARTModel(test_config)
cls_model = ARTCLSModel(artcls_config)
cls_pretrain = ART_CLS_PreTrain(artcls_config)
# Disabled: load a training checkpoint into art_model.model and save the
# result next to the ART config.
# resumeLoc = './ART/model/ART/modelsave/checkpoint.pth.tar'
# # 2. load model
# checkpoint = torch.load(resumeLoc)
# art_model.model.load_state_dict(checkpoint['state_dict'])

# art_model.save_pretrained('test_config-art')

In [5]:
from transformers import AutoConfig, AutoModel

# Register the ART config under the key "ART" and map it to ARTModel so the
# Auto* factories can resolve it.
# NOTE(review): the key presumably has to match the config class's model_type — confirm in tf_config.
AutoConfig.register("ART", ARTConfig)
AutoModel.register(ARTConfig, ARTModel)

# Same registration for the encoder-classifier config / model pair.
AutoConfig.register("ARTEncoder_CLSConfig", ARTEncoder_CLSConfig)
AutoModel.register(ARTEncoder_CLSConfig, ARTCLSModel)

In [None]:
from transformers import AutoModel

# Load the ART model from the "test_config-art" directory through AutoModel
# (relies on the registrations made in the previous cell).
model = AutoModel.from_pretrained('test_config-art')
# Load the target (classifier) model
target_model = AutoModel.from_pretrained('artcls-config')



### Extract Module Weights

In [None]:
from transformers import AutoModel

# Load the source (ART) model from its saved directory.
source_model = AutoModel.from_pretrained('test_config-art')
source_core = source_model.model

# Snapshot the parameters of the two sub-modules that are transferred:
# the encoder and the source embedding (src_embed).
encoder_weights = source_core.encoder.state_dict()
src_expandcov_weights = source_core.src_embed.state_dict()

# Build the target classifier from its config, then copy each snapshot into
# the sub-module of the same name.
target_model = ARTCLSModel(artcls_config)
target_core = target_model.model
target_core.encoder.load_state_dict(encoder_weights)
target_core.src_embed.load_state_dict(src_expandcov_weights)

print("Encoder weights successfully transferred!")

# Persist the classifier with the transferred weights.
target_model.save_pretrained('artcls-config')



In [None]:
from tf_model import ARTCLSModel
from torchinfo import summary

# model = ARTCLSModel.from_pretrained('artcls-config')

# Dummy input shapes given to torchinfo, in the order they are passed to art_model.
art_input_shapes = [
    (32, 30, 1024),
    (32, 30, 1024),
    (32, 1024, 1024),
    (32, 1024, 1024),
]
# Columns to include in the printed table.
summary_columns = [
    "input_size",
    "output_size",
    "num_params",
    "params_percent",
    "kernel_size",
]
print(summary(art_model, input_size=art_input_shapes, col_names=summary_columns))

# print(summary(cls_model, input_size=[(32, 30, 1024),(32,1024,1024)], col_names=["input_size", "output_size", "num_params",  "params_percent", "kernel_size"]))

In [None]:
import sys

# Put the project's EEGART folder first on the import path before importing from it.
sys.path.insert(0, './FirstMultiModel/EEGART')
from tf_model import make_model
from torchinfo import summary

from tf_config import ARTConfig, ARTEncoder_CLSConfig
from tf_model import ARTModel, ARTCLSModel, ART_CLS_PreTrain
import torch
from torch.nn import CrossEntropyLoss

# Write both configs to disk, then read them back so the models below are
# built from exactly what was saved.
art_config = ARTConfig(src_channel_size=30, tgt_channel_size=30, N=2)
art_config.save_pretrained("test_config-art")

artcls_config = ARTEncoder_CLSConfig(src_channel_size=30, tgt_channel_size=2, N=2)
artcls_config.save_pretrained("artcls-config")

test_config = ARTConfig.from_pretrained("test_config-art")
artcls_config = ARTEncoder_CLSConfig.from_pretrained("artcls-config")
# print(test_config)
# print(artcls_config)

art_model = ARTModel(test_config)
cls_model = ARTCLSModel(artcls_config)
cls_pretrain = ART_CLS_PreTrain(artcls_config)

# Simulated input data
src = torch.randn(32, 30, 1024)  # shape: (32, 30, 1024)
src_mask = torch.randn(32, 1024, 1024)  # shape: (32, 1024, 1024)
# label = torch.randint(0, 2, (32,))
label = torch.randn(32, 30, 1024)

# Use the GPU when one is available, otherwise fall back to the CPU.
# (The previous expression returned "cpu" on both branches.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the data and every model to the selected device. The models must live
# on the same device as their inputs, including art_model, which is the one
# actually called below.
src = src.to(device)
src_mask = src_mask.to(device)
art_model = art_model.to(device)
cls_model = cls_model.to(device)
cls_pretrain = cls_pretrain.to(device)
label = label.to(device)

# --- ART classifier test (disabled) ---
# output = cls_model(src, None, return_dict = True)
# logits = output.last_hidden_state.squeeze(dim=1)  # shape: [32, 2]
# print(output.last_hidden_state.shape)
# loss_fct = CrossEntropyLoss()
# loss = loss_fct(logits, label)
# print(loss)

# loss = cls_pretrain(src, None, label)
# print(loss.loss)

# --- ART test ---
# Feed the same tensor as source and target, without masks, and print the
# loss the model reports against the random labels.
output = art_model(src=src, tgt=src, src_mask=None, tgt_mask=None, labels=label, return_dict=True)
print(output.loss)

### Trainer 

##### Simulate Dataset

In [7]:
import torch
from torch.utils.data import Dataset, DataLoader

# Mock dataset for the ART classifier
class MockDataset(Dataset):
    """Randomly generated stand-in dataset for the ART classifier.

    Holds three pre-generated tensors:
      * ``data``   -- normal noise, shape (num_samples, seq_len, input_dim)
      * ``masks``  -- normal noise, shape (num_samples, input_dim, input_dim)
      * ``labels`` -- integers in [0, num_classes), shape (num_samples,)

    Indexing returns a dict with the keys ``"src"``, ``"src_mask"`` and
    ``"label"`` for the requested sample.
    """

    def __init__(self, num_samples, seq_len, input_dim, num_classes):
        self.num_samples, self.seq_len = num_samples, seq_len
        self.input_dim, self.num_classes = input_dim, num_classes

        # Draw order (data, masks, labels) is kept fixed so a given RNG seed
        # always yields the same tensors.
        data_shape = (num_samples, seq_len, input_dim)
        mask_shape = (num_samples, input_dim, input_dim)
        self.data = torch.randn(*data_shape)    # simulated src
        self.masks = torch.randn(*mask_shape)   # simulated src_mask
        self.labels = torch.randint(0, num_classes, (num_samples,))  # simulated label

    def __len__(self):
        """Return the number of samples the dataset was built with."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors."""
        return dict(
            src=self.data[idx],
            src_mask=self.masks[idx],
            label=self.labels[idx],
        )

# Mock dataset parameters
# NOTE(review): MockDataset pre-allocates a (num_samples, input_dim, input_dim)
# mask tensor, so 1000 samples at input_dim=1024 is ~1.05e9 elements (roughly
# 4 GiB if torch's default dtype is float32) — confirm this size is intended.
train_dataset = MockDataset(num_samples=1000, seq_len=30, input_dim=1024, num_classes=2)
eval_dataset = MockDataset(num_samples=200, seq_len=30, input_dim=1024, num_classes=2)



In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

# Mock dataset for the ART model
class MockDataset(Dataset):
    """Randomly generated stand-in dataset for the ART model.

    Holds three pre-generated normal-noise tensors:
      * ``data``   -- shape (num_samples, seq_len, input_dim)
      * ``masks``  -- shape (num_samples, input_dim, input_dim)
      * ``labels`` -- shape (num_samples, seq_len, input_dim)

    Indexing returns a dict in which ``"src"``/``"tgt"`` both come from
    ``data`` and ``"src_mask"``/``"tgt_mask"`` both come from ``masks``.
    """

    def __init__(self, num_samples, seq_len, input_dim):
        self.num_samples, self.seq_len, self.input_dim = num_samples, seq_len, input_dim

        # Draw order (data, masks, labels) is kept fixed so a given RNG seed
        # always yields the same tensors.
        signal_shape = (num_samples, seq_len, input_dim)
        self.data = torch.randn(*signal_shape)                           # simulated src
        self.masks = torch.randn(num_samples, input_dim, input_dim)      # simulated src_mask
        self.labels = torch.randn(*signal_shape)                         # simulated label

    def __len__(self):
        """Return the number of samples the dataset was built with."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors."""
        return dict(
            src=self.data[idx],
            tgt=self.data[idx],
            src_mask=self.masks[idx],
            tgt_mask=self.masks[idx],
            label=self.labels[idx],
        )

# Mock dataset parameters
# NOTE(review): each dataset pre-allocates a (num_samples, input_dim, input_dim)
# mask tensor; at input_dim=1024 that is ~1.05e8 elements for the 100-sample
# training set — keep num_samples small.
train_dataset = MockDataset(num_samples=100, seq_len=30, input_dim=1024)
eval_dataset = MockDataset(num_samples=20, seq_len=30, input_dim=1024)


In [None]:
# Inspect one sample from the dataset
sample = train_dataset[0]
print("Sample src shape:", sample["src"].shape)       # (30, 1024)
print("Sample src_mask shape:", sample["src_mask"].shape)  # (1024, 1024)
# Prints the whole label: a scalar class id for the classifier dataset, a full
# tensor when train_dataset is the ART-model variant.
print("Sample label:", sample["label"])              # label value


#### ART Trainer

In [None]:
from transformers import Trainer, TrainingArguments
import torch
import torch.nn as nn
import numpy as np

# Custom data collator
class SignalDataCollator:
    """Stack per-sample dicts into one batch dict of keyword arguments.

    Reads the ``"src"``, ``"src_mask"`` and ``"label"`` entries of every
    feature. The stacked source tensor is reused as ``"tgt"`` and the stacked
    mask as ``"tgt_mask"``; a constant ``"return_dict": True`` entry is added.
    """

    def __call__(self, features):
        def _stack(key):
            # Stack field `key` of every sample along a new leading dimension.
            return torch.stack([sample[key] for sample in features])

        signals = _stack("src")
        attn_masks = _stack("src_mask")
        targets = _stack("label")

        batch = {}
        batch["src"] = signals
        batch["tgt"] = signals            # same tensor object as "src"
        batch["src_mask"] = attn_masks
        batch["tgt_mask"] = attn_masks    # same tensor object as "src_mask"
        batch["labels"] = targets
        batch["return_dict"] = True
        return batch


# Custom evaluation metric
def compute_metrics(eval_preds):
    """Return the mean squared error between predictions and targets.

    ``eval_preds`` must unpack into exactly two objects,
    ``(predictions, targets)``, that support element-wise subtraction,
    ``**`` and ``.mean()``. The result is ``{"mse": <mean squared error>}``.
    """
    predicted, expected = eval_preds
    residual = predicted - expected
    squared_error = residual ** 2
    return {"mse": squared_error.mean()}

# Training arguments: results go to ./results, evaluation runs once per epoch,
# batches of 16 for both training and evaluation, 3 epochs.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
)

# Set up the Trainer for art_model with the custom collator and the MSE metric
trainer = Trainer(
    model=art_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=SignalDataCollator(),
    compute_metrics=compute_metrics,
)

# Start training
trainer.train()


#### CLS Trainer

In [None]:
from transformers import Trainer, TrainingArguments
import numpy as np
import evaluate 

# Disabled: accuracy-style metric computed from the argmax of the logits.
# metric = evaluate.load("glue", "mrpc")

# def compute_metrics(eval_preds):
#     logits, labels = eval_preds
#     predictions = np.argmax(logits, axis=-1)
#     return metric.compute(predictions=predictions, references=labels)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",       # directory where outputs are saved
    eval_strategy="epoch",  # replaces evaluation_strategy
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    # save_strategy="epoch",       # save the model once per epoch
    # logging_dir="./logs",        # log directory
    # logging_steps=10,
)
# training_args = TrainingArguments("test-trainer", eval_strategy="epoch")

# Create the Trainer (no data_collator given, so the Trainer's default is used)
# NOTE(review): this cell sits under the "CLS Trainer" heading but trains
# art_model; cls_pretrain looks like the intended model — confirm before running.
trainer = Trainer(
    model=art_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # compute_metrics=compute_metrics,
)

trainer.train()
