## Artifact Removal Transformer

### Model Overview

In [1]:
import sys
from pathlib import Path

# Make the project-local EEGART sources (tf_model, tf_config) importable.
# The directory is resolved to an absolute path so a later change of working
# directory cannot break imports. Any existing copy is dropped before
# re-inserting it at the front, so re-running this cell keeps exactly one
# entry instead of growing sys.path each time.
EEGART_DIR = str(Path("./FirstMultiModel/EEGART").resolve())
sys.path[:] = [entry for entry in sys.path if entry != EEGART_DIR]
sys.path.insert(0, EEGART_DIR)
from tf_model import make_model
from torchinfo import summary

# model = make_model(30, 30, N=2)
# print(summary(model, input_size=[(32, 30, 120),(32, 30, 120),(32, 120, 120),(32, 120, 120)], col_names=["input_size", "output_size", "num_params", "params_percent", "kernel_size"]))

  from .autonotebook import tqdm as notebook_tqdm


### Hugging Face Config

##### Save Pre-training Configs

In [2]:
from tf_config import ARTConfig, ARTEncoder_CLSConfig

# Both models read 30 source channels and stack N=2 layers. They differ only in
# tgt_channel_size: 30 for the ART model, 2 for the encoder + classification
# head (which appears to be the number of classes; the logits below are [32, 2]).
shared_config_kwargs = {"src_channel_size": 30, "N": 2}
art_config = ARTConfig(tgt_channel_size=30, **shared_config_kwargs)
artcls_config = ARTEncoder_CLSConfig(tgt_channel_size=2, **shared_config_kwargs)

# Write each config into its own directory so from_pretrained can reload it.
for cfg, out_dir in ((art_config, "test_config-art"), (artcls_config, "artcls-config")):
    cfg.save_pretrained(out_dir)

In [3]:
# Round-trip check: reload both configs from disk and echo them, one per line.
test_config, artcls_config = (
    ARTConfig.from_pretrained("test_config-art"),
    ARTEncoder_CLSConfig.from_pretrained("artcls-config"),
)
print(test_config, artcls_config, sep="\n")


ARTConfig {
  "N": 2,
  "d_ff": 2048,
  "d_model": 128,
  "dropout": 0.1,
  "h": 8,
  "model_type": "ART",
  "src_channel_size": 30,
  "tgt_channel_size": 30,
  "transformers_version": "4.46.1"
}

ARTEncoder_CLSConfig {
  "N": 2,
  "d_ff": 2048,
  "d_model": 128,
  "dropout": 0.1,
  "h": 8,
  "model_type": "ARTEncoder_CLSConfig",
  "src_channel_size": 30,
  "tgt_channel_size": 2,
  "time_len": 1024,
  "transformers_version": "4.46.1"
}



##### Save Pre-trained Model Weights

In [4]:
from tf_model import ARTModel, ARTCLSModel, ART_CLS_PreTrain
import torch

# Freshly initialised models built from the configs reloaded above:
#   art_model    - ART model (config "test_config-art")
#   cls_model    - encoder + classification head
#   cls_pretrain - classifier wrapper that takes labels and returns an output
#                  with a .loss attribute (used by the Trainer further down)
art_model = ARTModel(test_config)
cls_model, cls_pretrain = (
    ARTCLSModel(artcls_config),
    ART_CLS_PreTrain(artcls_config),
)

# Optional: initialise art_model from a local training checkpoint and store it
# in Hugging Face format, so AutoModel.from_pretrained('test_config-art') finds
# weights there.
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints
# (or pass weights_only=True).
# resumeLoc = './ART/model/ART/modelsave/checkpoint.pth.tar'
# checkpoint = torch.load(resumeLoc)
# art_model.model.load_state_dict(checkpoint['state_dict'])
# art_model.save_pretrained('test_config-art')

  nn.init.xavier_uniform(p)


In [5]:
from transformers import AutoConfig, AutoModel

# Map each custom model_type string to its config class, and each config class
# to its model class, so AutoConfig / AutoModel.from_pretrained can resolve the
# directories saved above.
# exist_ok=True keeps the cell re-runnable. Without it, a second run raises
# ValueError because the model_type / config class is already registered.
CUSTOM_AUTO_MODELS = (
    ("ART", ARTConfig, ARTModel),
    ("ARTEncoder_CLSConfig", ARTEncoder_CLSConfig, ARTCLSModel),
)
for model_type, config_cls, model_cls in CUSTOM_AUTO_MODELS:
    AutoConfig.register(model_type, config_cls, exist_ok=True)
    AutoModel.register(config_cls, model_cls, exist_ok=True)

In [None]:
from transformers import AutoModel

# Check that the registrations let AutoModel rebuild both custom models from
# their saved directories; the second one is the classification target model.
# NOTE(review): from_pretrained also needs weight files. Above,
# "test_config-art" only receives config.json (its save_pretrained call is
# commented out), and "artcls-config" gets weights only in the transfer cell
# below. On a fresh kernel this cell therefore presumably raises OSError;
# confirm and reorder if needed.
model, target_model = (
    AutoModel.from_pretrained(checkpoint_dir)
    for checkpoint_dir in ('test_config-art', 'artcls-config')
)



### Extract Module Weights

In [None]:
from transformers import AutoModel

# Donor: the ART model stored under "test_config-art".
# NOTE(review): this needs weight files in that directory, not just config.json.
source_model = AutoModel.from_pretrained('test_config-art')
donor = source_model.model

# Snapshot the donor's encoder stack and its source-side embedding (src_embed).
encoder_weights, src_expandcov_weights = (
    donor.encoder.state_dict(),
    donor.src_embed.state_dict(),
)

# Recipient: a freshly initialised encoder + classification head. Its encoder
# and src_embed are overwritten with the donor's parameters; load_state_dict is
# strict by default, so any key or shape mismatch raises here.
target_model = ARTCLSModel(artcls_config)
recipient = target_model.model
recipient.encoder.load_state_dict(encoder_weights)
recipient.src_embed.load_state_dict(src_expandcov_weights)

print("Encoder weights successfully transferred!")
# Persist weights + config so the classifier can be reloaded with from_pretrained.
target_model.save_pretrained('artcls-config')



In [None]:
from tf_model import ARTCLSModel
from torchinfo import summary

SUMMARY_COLUMNS = ["input_size", "output_size", "num_params", "params_percent", "kernel_size"]

# Alternatives kept for reference (inactive):
#   model = ARTCLSModel.from_pretrained('artcls-config')   # classifier with transferred weights
#   summary(art_model, input_size=[(32, 30, 1024), (32, 30, 1024), (32, 1024, 1024), (32, 1024, 1024)], col_names=SUMMARY_COLUMNS)

# Layer report for the classifier, fed a (32, 30, 1024) source batch and a
# (32, 1024, 1024) mask.
cls_summary = summary(
    cls_model,
    input_size=[(32, 30, 1024), (32, 1024, 1024)],
    col_names=SUMMARY_COLUMNS,
)
print(cls_summary)

In [6]:
import torch
from torch.nn import CrossEntropyLoss

# Smoke test on one random batch. Compare a hand-computed cross-entropy on
# cls_model's logits with the loss returned by cls_pretrain.
src = torch.randn(32, 30, 1024)          # (batch, 30 channels, 1024 steps) - matches src_channel_size / time_len
src_mask = torch.randn(32, 1024, 1024)   # placeholder: both calls below pass None instead
label = torch.randint(0, 2, (32,))       # binary class ids

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inputs and *both* models must share a device, otherwise the cls_model call
# fails as soon as the device is a GPU.
src = src.to(device)
src_mask = src_mask.to(device)
label = label.to(device)
cls_model = cls_model.to(device)
cls_pretrain = cls_pretrain.to(device)

output = cls_model(src, None)
logits = output.last_hidden_state.squeeze(dim=1)  # shape: [32, 2]
loss = CrossEntropyLoss()(logits, label)
print(loss)

# ART_CLS_PreTrain takes the labels itself and exposes the loss as `.loss`.
pretrain_output = cls_pretrain(src, None, label)
print(pretrain_output.loss)

tensor(0.7003, grad_fn=<NllLossBackward0>)
tensor(0.7265, grad_fn=<NllLossBackward0>)


In [None]:
from transformers import AutoProcessor, AutoModelForCausalLM
import requests
from PIL import Image

# Scratch exploration of "microsoft/git-base-coco", a pretrained model
# (presumably image captioning) that is unrelated to the EEG pipeline.
# Distinct names are used so this cell does not overwrite the ART `model`
# loaded earlier. NOTE: downloads processor + weights from the Hugging Face Hub
# on first run.
git_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Display the bound generate method (exploration leftover).
git_model.generate

### Trainer 

##### Simulated Dataset

In [7]:
import torch
from torch.utils.data import Dataset, DataLoader

# Mock custom Dataset
class MockDataset(Dataset):
    """Random in-memory dataset used to smoke-test the Trainer pipeline.

    Every item is a dict with keys:
      * ``"src"``      - float tensor of shape ``(seq_len, input_dim)``
      * ``"src_mask"`` - float tensor of shape ``(input_dim, input_dim)``
      * ``"label"``    - int64 scalar tensor in ``[0, num_classes)``

    In this notebook ``seq_len`` is given the channel count (30) and
    ``input_dim`` the time length (1024), so ``"src"`` matches the per-sample
    ``(30, 1024)`` input used in the smoke test.

    Args:
        num_samples: number of items.
        seq_len: size of the first ``src`` axis.
        input_dim: size of the second ``src`` axis and of both mask axes.
        num_classes: labels are drawn uniformly from ``range(num_classes)``.
        seed: optional int. When given, a private ``torch.Generator`` seeded
            with it is used, so the data is reproducible regardless of global
            RNG state. When ``None`` (default), the global torch RNG is used.

    Note:
        All tensors are materialised eagerly. With the default float32 dtype
        the masks alone take ``num_samples * input_dim**2 * 4`` bytes
        (about 3.9 GiB for 1000 x 1024 x 1024).

    Raises:
        ValueError: if ``num_classes < 1``.
    """

    def __init__(self, num_samples, seq_len, input_dim, num_classes, seed=None):
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.num_classes = num_classes

        # None -> draw from the global RNG, so torch.manual_seed still applies.
        generator = None if seed is None else torch.Generator().manual_seed(seed)

        # Randomly generated data
        self.data = torch.randn(num_samples, seq_len, input_dim, generator=generator)  # mock src
        # NOTE(review): Gaussian noise is used as the "mask"; how the model
        # interprets these values is not visible from this notebook.
        self.masks = torch.randn(num_samples, input_dim, input_dim, generator=generator)  # mock src_mask
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=generator)  # mock label

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "src": self.data[idx],
            "src_mask": self.masks[idx],
            "label": self.labels[idx],
        }

# Mock dataset parameters: 30 x 1024 "src" per sample, binary labels.
# NOTE: masks are materialised eagerly as num_samples x 1024 x 1024 floats,
# roughly 3.9 GiB (train) + 0.8 GiB (eval) at float32.
shared_dims = {"seq_len": 30, "input_dim": 1024, "num_classes": 2}
train_dataset = MockDataset(num_samples=1000, **shared_dims)
eval_dataset = MockDataset(num_samples=200, **shared_dims)



In [8]:
# 檢查數據集中的一個樣本
sample = train_dataset[0]
print("Sample src shape:", sample["src"].shape)       # (30, 1024)
print("Sample src_mask shape:", sample["src_mask"].shape)  # (1024, 1024)
print("Sample label:", sample["label"])              # 標籤值


Sample src shape: torch.Size([30, 1024])
Sample src_mask shape: torch.Size([1024, 1024])
Sample label: tensor(0)


In [20]:
from transformers import Trainer, TrainingArguments
import numpy as np
import evaluate 

# Accuracy + binary F1 for the 2-class task. These are the same two keys the
# GLUE/MRPC metric reports, without tying the notebook to a text benchmark.
metric = evaluate.combine(["accuracy", "f1"])


def compute_metrics(eval_preds):
    """Convert Trainer evaluation output into accuracy / F1 scores.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair, e.g. a transformers
            ``EvalPrediction``. The last axis of ``predictions`` indexes classes.
            When the model returns several outputs, ``predictions`` arrives as a
            tuple and its first element is taken as the class scores.

    Returns:
        dict with ``"accuracy"`` and ``"f1"`` entries.
    """
    logits, labels = eval_preds
    # Models that return extra tensors produce a tuple of predictions; the
    # class scores are conventionally the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# Training arguments in effect: checkpoints/logs under "test-trainer" and an
# evaluation pass after every epoch. Everything else uses the TrainingArguments
# defaults (the log shows 375 steps = 3 epochs x 125 batches of 8).
#
# A more explicit alternative configuration (inactive):
#   TrainingArguments(
#       output_dir="./results",          # directory for saved models
#       eval_strategy="epoch",           # replaces evaluation_strategy
#       per_device_train_batch_size=16,
#       per_device_eval_batch_size=16,
#       num_train_epochs=3,
#       save_strategy="epoch",           # save a checkpoint once per epoch
#       logging_dir="./logs",            # log directory
#       logging_steps=10,
#   )
training_args = TrainingArguments(output_dir="test-trainer", eval_strategy="epoch")

# Wire the classification wrapper to the mock datasets and metric function.
# The default data collator batches each item dict (mapping "label" to
# "labels"), and the batch is passed to the model as keyword arguments.
trainer_setup = {
    "model": cls_pretrain,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "compute_metrics": compute_metrics,
}
trainer = Trainer(**trainer_setup)

trainer.train()


 33%|███▎      | 124/375 [00:05<00:11, 21.85it/s]
 34%|███▍      | 127/375 [00:06<00:22, 10.85it/s]

{'eval_loss': 0.746819019317627, 'eval_accuracy': 0.515, 'eval_f1': 0.5570776255707762, 'eval_runtime': 0.4518, 'eval_samples_per_second': 442.655, 'eval_steps_per_second': 55.332, 'epoch': 1.0}


 67%|██████▋   | 250/375 [00:12<00:05, 21.60it/s]
 67%|██████▋   | 253/375 [00:12<00:11, 10.76it/s]

{'eval_loss': 0.7530338168144226, 'eval_accuracy': 0.525, 'eval_f1': 0.5581395348837209, 'eval_runtime': 0.4595, 'eval_samples_per_second': 435.272, 'eval_steps_per_second': 54.409, 'epoch': 2.0}


 99%|█████████▉| 373/375 [00:18<00:00, 21.71it/s]
100%|██████████| 375/375 [00:18<00:00, 19.81it/s]

{'eval_loss': 0.75572669506073, 'eval_accuracy': 0.515, 'eval_f1': 0.5488372093023256, 'eval_runtime': 0.4333, 'eval_samples_per_second': 461.569, 'eval_steps_per_second': 57.696, 'epoch': 3.0}
{'train_runtime': 18.9306, 'train_samples_per_second': 158.474, 'train_steps_per_second': 19.809, 'train_loss': 0.5835503743489583, 'epoch': 3.0}





TrainOutput(global_step=375, training_loss=0.5835503743489583, metrics={'train_runtime': 18.9306, 'train_samples_per_second': 158.474, 'train_steps_per_second': 19.809, 'total_flos': 0.0, 'train_loss': 0.5835503743489583, 'epoch': 3.0})