In [None]:
# ============================================================
# Cell 1: Setup — Drive mount, clone CRNet, install deps
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

import os, subprocess, sys

CRNET_DIR  = '/content/CRNet'
DATA_DIR   = '/content/drive/MyDrive/MambaCompression/MambaIC/data'
CKPT_DIR   = '/content/CRNet/checkpoints'

# Clone repo
if not os.path.isdir(CRNET_DIR):
    !git clone https://github.com/Kylin9511/CRNet.git {CRNET_DIR}
else:
    print('CRNet already cloned')

os.makedirs(CKPT_DIR, exist_ok=True)

# Install deps
!pip install thop gdown -q

sys.path.insert(0, CRNET_DIR)
print('Setup done. DATA_DIR:', DATA_DIR)

In [None]:
# ============================================================
# Cell 2: Download pretrained checkpoints (Google Drive)
# Folder: https://drive.google.com/drive/folders/16hQsrxkFuyjtmW4DOI8-Tix5TP5JfIia
# Expected: in_04, in_08, in_16, in_32, in_64
#           out_04, out_08, out_16, out_32, out_64
# ============================================================
import os

FOLDER_ID = '16hQsrxkFuyjtmW4DOI8-Tix5TP5JfIia'
CKPT_DIR  = '/content/CRNet/checkpoints'

!gdown --folder https://drive.google.com/drive/folders/{FOLDER_ID} \
       --output {CKPT_DIR} --quiet

print('\nDownloaded checkpoints:')
for f in sorted(os.listdir(CKPT_DIR)):
    print(' ', f)

In [None]:
# ============================================================
# Cell 3: Performance verification — all CRs × both scenarios
# 논문 목표치:
#   indoor : CR=1/4 -17.36 | 1/8 -12.70 | 1/16 -8.65 | 1/32 -6.24 | 1/64 -5.84
#   outdoor: CR=1/4  -8.75 | 1/8  -7.61 | 1/16 -4.51 | 1/32 -2.81 | 1/64 -1.93  (CsiNet baseline)
#   CRNet-cosine indoor : -26.99 / -16.01 / -11.35 / -8.93 / -6.49
#   CRNet-cosine outdoor: -12.71 /  -8.04 /  -5.44 / -3.51 / -2.22
# ============================================================
%cd /content/CRNet

SCENARIOS = ['in', 'out']
CRS       = [4, 8, 16, 32, 64]
DATA_DIR  = '/content/drive/MyDrive/MambaCompression/MambaIC/data'
CKPT_DIR  = '/content/CRNet/checkpoints'

for scenario in SCENARIOS:
    print(f'\n{"="*55}')
    print(f'  Scenario: {"Indoor" if scenario=="in" else "Outdoor"}')
    print(f'{"="*55}')
    for cr in CRS:
        # checkpoint name: in_04, in_08, ... out_04 ...
        ckpt_name = f'{scenario}_{cr:02d}'
        ckpt_path = os.path.join(CKPT_DIR, ckpt_name)
        if not os.path.exists(ckpt_path):
            print(f'  CR=1/{cr:<2}  [SKIP] checkpoint not found: {ckpt_name}')
            continue
        print(f'\n  CR = 1/{cr}')
        !python main.py \
            --data-dir {DATA_DIR} \
            --scenario {scenario} \
            --pretrained {ckpt_path} \
            --evaluate \
            --batch-size 200 \
            --workers 2 \
            --cr {cr}

In [None]:
# ============================================================
# Cell 4: Encoder FLOPs measurement (thop)
# Encoder = encoder1 + encoder2 + encoder_conv + encoder_fc
# Input: (1, 2, 32, 32)
# ============================================================
import sys
sys.path.insert(0, '/content/CRNet')

import torch
import torch.nn as nn
from thop import profile, clever_format
from models.crnet import CRNet

class CRNetEncoderOnly(nn.Module):
    """Encoder half of a CRNet model, used to profile the encoder in isolation.

    The four encoder submodules of ``full_model`` (``encoder1``, ``encoder2``,
    ``encoder_conv``, ``encoder_fc``) are registered by reference, not copied,
    so this wrapper shares parameters with ``full_model``.
    """

    def __init__(self, full_model):
        super().__init__()
        # nn.Module.__setattr__ registers each submodule exactly as a plain
        # attribute assignment would, in this fixed order.
        for part in ('encoder1', 'encoder2', 'encoder_conv', 'encoder_fc'):
            setattr(self, part, getattr(full_model, part))

    def forward(self, x):
        """Run the encoder path and return the encoder_fc output.

        The outputs of the two parallel branches (encoder1, encoder2) are
        concatenated along dim 1 and fused by encoder_conv. The result is then
        flattened per sample and passed through encoder_fc.
        """
        batch = x.size(0)
        branches = torch.cat([self.encoder1(x), self.encoder2(x)], dim=1)
        fused = self.encoder_conv(branches)
        return self.encoder_fc(fused.view(batch, -1))

# Profile encoder-only and full-model complexity for every compression ratio.
# thop's `profile` returns (MACs, params), counting one multiply-accumulate as
# one op. The encoder FLOPs column is therefore 2 x MACs.
#
# The paper/README "FLOPs" numbers are NOT encoder FLOPs. Their CR=1/4 vs
# CR=1/64 gap (5.12M - 3.16M ~= 1.96M) matches the size change of TWO
# 2048-wide FC layers: 2 x 2048 x (512 - 32) ~= 1.97M. That indicates
# full-model (encoder + decoder) thop MAC counts. The comparison assumes
# encoder_fc maps 2048 -> 2048/cr and decoder_fc mirrors it; verify in
# models/crnet.py. Compare the paper value against the "Full MACs" column.
dummy = torch.randn(1, 2, 32, 32)

# Values reported in the CRNet paper, keyed by CR denominator.
paper = {4: '5.12M', 8: '4.07M', 16: '3.55M', 32: '3.28M', 64: '3.16M'}

print(f'  {"CR":<6} {"Enc MACs":>10} {"Enc FLOPs":>10} {"Enc Params":>11} '
      f'{"Full MACs":>10} {"Paper":>8}')
print(f'  {"-" * 60}')

for cr, paper_v in paper.items():
    full = CRNet(reduction=cr).eval()
    # Full model: the number comparable to the paper column.
    full_macs, _ = profile(full, inputs=(dummy,), verbose=False)
    enc = CRNetEncoderOnly(full).eval()
    macs, params = profile(enc, inputs=(dummy,), verbose=False)
    # Format all numbers in ONE call. clever_format returns a bare string
    # (not a 1-element sequence) for a single number, so indexing that result
    # with [0] would yield only its first character.
    macs_s, flops_s, params_s, full_macs_s = clever_format(
        [macs, 2 * macs, params, full_macs], '%.3f')
    print(f'  {("1/" + str(cr)):<6} {macs_s:>10} {flops_s:>10} {params_s:>11} '
          f'{full_macs_s:>10} {paper_v:>8}')