# TransNet: CSI Feedback Compression (Transformer Baseline)

**Note:** 기존 체크포인트가 GitHub 코드와 호환되지 않아 재학습이 필요합니다.  
CPU에서는 매우 느리므로, Colab GPU 사용을 권장합니다.

In [None]:
import os, sys

# --- Colab / Local 자동 감지 ---
IN_COLAB = 'google.colab' in sys.modules or os.path.exists('/content')

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    TRANSNET_ROOT = "/content/drive/MyDrive/MambaCompression/TransNet-master"
    DATA_DIR      = "/content/drive/MyDrive/MambaCompression/MambaIC/data"
else:
    TRANSNET_ROOT = r"G:\내 드라이브\MambaCompression\TransNet-master"
    DATA_DIR      = r"G:\내 드라이브\MambaCompression\MambaIC\data"

os.chdir(TRANSNET_ROOT)

!{sys.executable} -m pip install thop scipy --quiet

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"CWD: {os.getcwd()}")

## Training

### Hyperparameters
- **d_model:** 64 (논문 기본값), **nhead:** 2, **num_layers:** 2, **dim_feedforward:** 2048
- **Scheduler:** Cosine Annealing (warmup 30 epochs, lr: 2e-4 → 5e-5)
- **Epochs:** 1000 (reference), batch_size=200 (단, 아래 Outdoor 학습 셀은 `BATCH_SIZE = 500` 사용)
- **CR:** 1/4, 1/8, 1/16, 1/32, 1/64

### Reference Performance (논문)
| CR | Indoor NMSE (dB) | Outdoor NMSE (dB) | FLOPs |
|----|-----------------|-------------------|-------|
| 1/4 | -29.22 | -13.99 | 35.72M |
| 1/8 | -21.62 | -9.57 | 34.70M |
| 1/16 | -14.98 | -6.90 | 34.14M |
| 1/32 | -9.83 | -3.77 | 33.88M |
| 1/64 | -5.77 | -2.20 | 33.75M |

In [None]:
import os, sys, shlex, subprocess, time, shutil, torch as _t

EPOCHS = 1000
D_MODEL = 64
SCENARIO = 'out'
# NOTE(review): the reference hyperparameters in the markdown above use
# batch_size=200; this outdoor run uses 500. Confirm this is intended before
# comparing results against the paper numbers.
BATCH_SIZE = 500
WORKERS = 2


def _load_ckpt(path):
    """Load the checkpoint at *path* onto the CPU, or return None if the file is missing."""
    if not os.path.exists(path):
        return None
    # weights_only=False unpickles arbitrary objects -- only load checkpoints you trust.
    return _t.load(path, map_location='cpu', weights_only=False)


def _ckpt_d_model(ckpt):
    """Return ckpt['state_dict']['predict.weight'].shape[0] (taken as d_model), or None if absent."""
    weight = ckpt.get('state_dict', {}).get('predict.weight')
    return None if weight is None else weight.shape[0]


def _stream(cmd):
    """Run the argv list *cmd*, echo its merged stdout/stderr line by line, return the exit code."""
    # errors='replace': an undecodable byte (e.g. a console/locale encoding
    # mismatch) must not kill the echo loop while the child keeps running.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, errors='replace') as proc:
        for line in proc.stdout:
            print(line, end='', flush=True)
    # Popen.__exit__ waits for the child, so returncode is populated here.
    return proc.returncode


for cr in [4, 16]:
    print(f"\n{'='*60}")
    print(f"  Training: Outdoor CR=1/{cr}, {EPOCHS} epochs, d_model={D_MODEL}, bs={BATCH_SIZE}")
    print(f"{'='*60}")

    save_dir = f'./checkpoints/cr{cr}_{SCENARIO}'
    last_path = os.path.join(save_dir, 'last.pth')

    # Load last.pth once; it drives the d_model check, the skip check and resume.
    ckpt = _load_ckpt(last_path)

    # Remove the previous checkpoint only when its d_model does not match.
    if ckpt is not None:
        ckpt_d = _ckpt_d_model(ckpt)
        if ckpt_d is None:
            print(f"  [WARN] CR=1/{cr}: 'predict.weight' not in checkpoint, cannot verify d_model")
        elif ckpt_d != D_MODEL:
            print(f"  [CLEAN] d_model mismatch ({ckpt_d} vs {D_MODEL}), removing {save_dir}")
            shutil.rmtree(save_dir)
            ckpt = None
        else:
            print(f"  [OK] CR=1/{cr} checkpoint d_model={ckpt_d} matches, keeping for resume")

    # Skip if fully trained.
    if ckpt is not None and ckpt['epoch'] >= EPOCHS:
        print(f"  [SKIP] Already done (epoch {ckpt['epoch']})")
        continue

    # argv list without a shell: paths with spaces (e.g. "G:\내 드라이브" or an
    # interpreter under "Program Files") need no manual quoting.
    # '-u' disables the child's stdout buffering so progress actually streams live.
    cmd = [
        sys.executable, '-u', 'main.py',
        '--data-dir', DATA_DIR,
        '--scenario', SCENARIO,
        '--batch-size', str(BATCH_SIZE),
        '--workers', str(WORKERS),
        '--cr', str(cr),
        '--d_model', str(D_MODEL),
        '--epochs', str(EPOCHS),
        '--scheduler', 'cosine',
    ]
    # Resume from last.pth if it survived the checks above.
    if ckpt is not None:
        print(f"  [RESUME] from epoch {ckpt['epoch']}")
        cmd += ['--resume', last_path]
    del ckpt  # release the loaded tensors before the long training run

    print(f"  CMD: {shlex.join(cmd)}\n", flush=True)

    t0 = time.time()
    returncode = _stream(cmd)
    elapsed = time.time() - t0

    if returncode != 0:
        print(f"\n  [ERROR] exit code {returncode}")
    print(f"\n  Elapsed: {elapsed/60:.1f} min")

In [None]:
import os, sys, shlex, subprocess, time
import torch as _t

EPOCHS = 1000
D_MODEL = 64
SCENARIO = 'in'
BATCH_SIZE = 200


def _load_ckpt(path):
    """Load the checkpoint at *path* onto the CPU, or return None if the file is missing."""
    if not os.path.exists(path):
        return None
    # weights_only=False unpickles arbitrary objects -- only load checkpoints you trust.
    return _t.load(path, map_location='cpu', weights_only=False)


def _ckpt_d_model(ckpt):
    """Return ckpt['state_dict']['predict.weight'].shape[0] (taken as d_model), or None if absent."""
    weight = ckpt.get('state_dict', {}).get('predict.weight')
    return None if weight is None else weight.shape[0]


def _stream(cmd):
    """Run the argv list *cmd*, echo its merged stdout/stderr line by line, return the exit code."""
    # errors='replace': an undecodable byte must not kill the echo loop mid-run.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, errors='replace') as proc:
        for line in proc.stdout:
            print(line, end='', flush=True)
    # Popen.__exit__ waits for the child, so returncode is populated here.
    return proc.returncode


for cr in [4, 8, 16, 32, 64]:
    print(f"\n{'='*60}")
    print(f"  Training: Indoor CR=1/{cr}, {EPOCHS} epochs, d_model={D_MODEL}")
    print(f"{'='*60}")

    save_dir = f'./checkpoints/cr{cr}_{SCENARIO}'
    last_path = os.path.join(save_dir, 'last.pth')
    best_path = os.path.join(save_dir, 'best_nmse.pth')

    ckpt = _load_ckpt(last_path)

    # Without last.pth there is no way to tell how far training got, so keep the
    # previous behaviour: an existing best_nmse.pth counts as a finished run.
    if ckpt is None and os.path.exists(best_path):
        print(f"  [SKIP] {save_dir}/best_nmse.pth already exists")
        continue

    if ckpt is not None:
        ckpt_d = _ckpt_d_model(ckpt)
        if ckpt_d is not None and ckpt_d != D_MODEL:
            # Non-destructive: leave the directory alone and let the user decide.
            print(f"  [SKIP] d_model mismatch ({ckpt_d} vs {D_MODEL}); remove {save_dir} to retrain")
            continue
        if ckpt['epoch'] >= EPOCHS:
            print(f"  [SKIP] Already done (epoch {ckpt['epoch']})")
            continue

    # argv list without a shell (no quoting issues for paths with spaces);
    # '-u' makes the child's stdout unbuffered so progress streams live.
    cmd = [
        sys.executable, '-u', 'main.py',
        '--data-dir', DATA_DIR,
        '--scenario', SCENARIO,
        '--batch-size', str(BATCH_SIZE),
        '--workers', '0',
        '--cr', str(cr),
        '--d_model', str(D_MODEL),
        '--epochs', str(EPOCHS),
        '--scheduler', 'cosine',
    ]
    # An interrupted run is resumed instead of being skipped or restarted.
    if ckpt is not None:
        print(f"  [RESUME] from epoch {ckpt['epoch']}")
        cmd += ['--resume', last_path]
    del ckpt  # release the loaded tensors before the long training run

    print(f"  CMD: {shlex.join(cmd)}\n", flush=True)

    t0 = time.time()
    returncode = _stream(cmd)
    elapsed = time.time() - t0

    if returncode != 0:
        print(f"\n  [ERROR] exit code {returncode}")
    print(f"\n  Elapsed: {elapsed/60:.1f} min")

## Evaluation

학습된 체크포인트로 평가합니다. `checkpoints/cr{CR}_{scenario}/best_nmse.pth` 경로를 사용합니다 (`scenario`: `in` 또는 `out`).

In [None]:
import os, sys, re, subprocess

D_MODEL = 64
results_out = {}

# Matches e.g. "NMSE: -13.99" or "NMSE: -1.399e+01".
_NMSE_PATTERN = re.compile(r'NMSE:\s*([-\d.e+]+)')


def parse_nmse(text):
    """Return the first ``NMSE: <number>`` value in *text* as a float, or None.

    The character class can capture tokens that are not valid floats (e.g. a
    lone '-'); those are skipped instead of raising ValueError.
    """
    for match in _NMSE_PATTERN.finditer(text):
        try:
            return float(match.group(1))
        except ValueError:
            continue
    return None


for cr in [4, 16]:
    ckpt = f'./checkpoints/cr{cr}_out/best_nmse.pth'
    if not os.path.exists(ckpt):
        print(f"  [SKIP] CR=1/{cr}: {ckpt} not found")
        continue

    # argv list without a shell: no quoting issues for paths containing spaces.
    cmd = [
        sys.executable, 'main.py',
        '--data-dir', DATA_DIR,
        '--scenario', 'out',
        '--pretrained', ckpt,
        '--evaluate',
        '--batch-size', '200',
        '--workers', '0',
        '--cr', str(cr),
        '--d_model', str(D_MODEL),
    ]
    print(f"\n{'='*60}")
    print(f"  Outdoor CR=1/{cr}")
    print(f"{'='*60}")
    result = subprocess.run(cmd, capture_output=True, text=True, errors='replace')
    # Slicing returns the whole string when it is shorter than 500 characters.
    print(result.stdout[-500:])
    if result.returncode != 0:
        print("STDERR:", result.stderr[-500:])

    # Fall back to stderr in case main.py reports through `logging`, whose
    # default handler writes to stderr.
    nmse = parse_nmse(result.stdout)
    if nmse is None:
        nmse = parse_nmse(result.stderr)
    if nmse is None:
        print(f"  [WARN] CR=1/{cr}: no 'NMSE:' value found in output")
    else:
        results_out[cr] = nmse

print(f"\n{'='*60}")
print("  Outdoor Summary")
print(f"{'='*60}")
ref = {4: -13.99, 16: -6.90}
print(f"{'CR':>6} | {'Measured':>10} | {'Reference':>10} | {'Delta':>8}")
print('-' * 45)
for cr in [4, 16]:
    meas = results_out.get(cr, float('nan'))
    print(f"  1/{cr:<3} | {meas:>10.2f} | {ref[cr]:>10.2f} | {meas - ref[cr]:>+8.2f}")

In [None]:
import os, sys, re, subprocess

D_MODEL = 64
results_in = {}

# Matches e.g. "NMSE: -29.22" or "NMSE: -2.922e+01".
_NMSE_PATTERN = re.compile(r'NMSE:\s*([-\d.e+]+)')


def parse_nmse(text):
    """Return the first ``NMSE: <number>`` value in *text* as a float, or None.

    The character class can capture tokens that are not valid floats (e.g. a
    lone '-'); those are skipped instead of raising ValueError.
    """
    for match in _NMSE_PATTERN.finditer(text):
        try:
            return float(match.group(1))
        except ValueError:
            continue
    return None


for cr in [4, 8, 16, 32, 64]:
    ckpt = f'./checkpoints/cr{cr}_in/best_nmse.pth'
    if not os.path.exists(ckpt):
        print(f"  [SKIP] CR=1/{cr}: {ckpt} not found")
        continue

    # argv list without a shell: no quoting issues for paths containing spaces.
    cmd = [
        sys.executable, 'main.py',
        '--data-dir', DATA_DIR,
        '--scenario', 'in',
        '--pretrained', ckpt,
        '--evaluate',
        '--batch-size', '200',
        '--workers', '0',
        '--cr', str(cr),
        '--d_model', str(D_MODEL),
    ]
    print(f"\n{'='*60}")
    print(f"  Indoor CR=1/{cr}")
    print(f"{'='*60}")
    result = subprocess.run(cmd, capture_output=True, text=True, errors='replace')
    # Slicing returns the whole string when it is shorter than 500 characters.
    print(result.stdout[-500:])
    if result.returncode != 0:
        print("STDERR:", result.stderr[-500:])

    # Fall back to stderr in case main.py reports through `logging`, whose
    # default handler writes to stderr.
    nmse = parse_nmse(result.stdout)
    if nmse is None:
        nmse = parse_nmse(result.stderr)
    if nmse is None:
        print(f"  [WARN] CR=1/{cr}: no 'NMSE:' value found in output")
    else:
        results_in[cr] = nmse

print(f"\n{'='*60}")
print("  Indoor Summary")
print(f"{'='*60}")
ref_in = {4: -29.22, 8: -21.62, 16: -14.98, 32: -9.83, 64: -5.77}
print(f"{'CR':>6} | {'Measured':>10} | {'Reference':>10} | {'Delta':>8}")
print('-' * 45)
for cr in [4, 8, 16, 32, 64]:
    meas = results_in.get(cr, float('nan'))
    print(f"  1/{cr:<3} | {meas:>10.2f} | {ref_in[cr]:>10.2f} | {meas - ref_in[cr]:>+8.2f}")