In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import einops
import random
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import os
from torchmetrics import StructuralSimilarityIndexMeasure

In [2]:
EPOCHS = 3000
BATCH_SIZE = 256

In [3]:
seed = 42
deterministic = True

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
	torch.backends.cudnn.deterministic = True
	torch.backends.cudnn.benchmark = False

In [4]:
# 1. ImageNet 데이터셋 불러오기
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # 이미지 크기를 64x64로 조정
    transforms.ToTensor(),  # 텐서로 변환
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 정규화
])

# ImageNet 데이터 로드
train_data = torchvision.datasets.STL10(root='./data', split='train', download=True, transform=transform)
# 시드 설정을 위한 생성기(generator) 생성
g = torch.Generator()
g.manual_seed(42)

# DataLoader 설정
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, generator=g)
val_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified


In [5]:
class SAAAE(nn.Module):
    def __init__(self, image_size=64, latent_dim=128):
        super(SAAAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)  # 잠재 공간의 평균
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)  # 로그 분산

        self.fc_decode = nn.Linear(latent_dim, 512 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # 이미지 생성이므로 [-1, 1] 범위로 출력
        )

        self.attention = nn.MultiheadAttention(embed_dim=image_size * image_size, num_heads=1, batch_first=True)
        self.attn_linear = nn.Sequential(
            nn.Linear(3* image_size * image_size, latent_dim),
            nn.SiLU(),
        )
    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar, attention_map):
        std = torch.exp(0.5 * logvar)  # 표준편차 계산
        eps = torch.randn_like(std)     # 노이즈 생성
        
        # Attention 맵을 0에서 1 사이로 정규화
        attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
        
        # Attention 맵을 반전하여 높은 값은 0, 낮은 값은 1로 변환
        positive_mask = 1 - attention_map  # 높은 부분을 0으로, 낮은 부분을 1로 설정
    
        # reparameterize 결과에 attention이 반영되도록 수정
        return mu + std * (eps * positive_mask)  # positive_mask가 낮은 부분만 노이즈가 추가됨


    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(-1, 512, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        # torch.Size([128, 3, 64, 64])
        re_x = einops.rearrange(x, 'b c h w -> b c (h w)')
        attention_out1, _ = self.attention(re_x, re_x, re_x)  # Self-Attention 계산 (B, C, D)
        attention_out2 = einops.rearrange(attention_out1, 'b c w -> b (c w)')
        attention_out1 = einops.rearrange(attention_out1, 'b c (h w) -> b c h w', h=64)
        attention_out3 = self.attn_linear(attention_out2)

        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar, attention_out3)

        return self.decode(z), mu, logvar, attention_out1, attention_out3



# VAE 손실 함수 정의
# VAE 학습 함수
def saaae_loss_fn(recon_x, x, mu, logvar, attention_map, criterion):
    #mask = (x != 0).float()  # 입력값에서 0인 값에 대한 마스크 적용

    # 재구성 손실 (MSE) 계산
    recon_loss = criterion(recon_x, x)

    # attention_map이 양수일 때는 KL 손실을 계산하지 않도록 마스크 적용
    #kl_mask = (attention_map <= 0).float()  
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # 총 손실 계산
    return recon_loss + kl_loss



In [6]:
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
import numpy as np
import os
import matplotlib.pyplot as plt

# 랜덤 패치 마스킹 함수
def random_mask(images, mask_size=8, num_patches=16):
    masked_images = images.clone()
    _, _, height, width = images.size()
    seed = 42
    torch.manual_seed(seed)
    for _ in range(num_patches):
        top = np.random.randint(0, height - mask_size)
        left = np.random.randint(0, width - mask_size)
        masked_images[:, :, top:top + mask_size, left:left + mask_size] = 0
    return masked_images

def calculate_psnr(reconstructed, original):
    mse = F.mse_loss(reconstructed, original, reduction='none')
    mse = mse.mean(dim=[1, 2, 3])  # 각 이미지의 MSE
    psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))  # PSNR 계산
    return psnr

# Evaluate VAE with MSE and PSNR, and save masked images and attention maps
def evaluate_vae(vae, data_loader, device, save_dir):
    vae.eval()
    total_mse = 0.0
    total_psnr = 0.0
    total_images = 0
    os.makedirs(save_dir, exist_ok=True)

def train(model, train_loader,val_loader, num_epochs,criterion, save_dir, device):
    # 생성된 이미지를 저장할 디렉터리 설정
    
    best_loss = float('inf')
    avg_loss = float('inf')
    losses, psnrs, mses, ssims = [], [], [], []
    best_model = None
    os.makedirs(save_dir, exist_ok=True)
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            
            # 이미지에 랜덤 마스킹 적용
            masked_images = random_mask(images)
            
            optimizer.zero_grad()
            
            # VAE에서 복원된 이미지와 잠재 벡터를 얻음
            recon_images, mu, logvar, attention_out1, attention_out3 = model(masked_images)
            loss = saaae_loss_fn(recon_images, images, mu, logvar, attention_out3, criterion)  # 실제 원본 이미지와 비교
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        avg_loss = train_loss/len(train_loader)
        losses.append(avg_loss)

        if avg_loss<best_loss:
            best_model = model.state_dict()
            best_loss = avg_loss
            
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")
        
        # 100 에포크마다 이미지를 복원하고 저장
        if (epoch) % 100 == 0:
            model.eval()
            with torch.no_grad():
                for batch_idx, (images, _) in enumerate(val_loader):
                    images = images.to(device)
                    
                    # 이미지에 랜덤 마스킹 적용
                    masked_images = random_mask(images)
                    
                    optimizer.zero_grad()
                    
                    # VAE에서 복원된 이미지와 잠재 벡터를 얻음
                    recon_images, mu, logvar, attention_out1, attention_out3 = model(masked_images)


                    # 원본 배치에서 이미지 가져오기
                    original_images = images.cpu()
                    masked_images_cpu = masked_images.cpu()
        
                    reconstructed_images = recon_images.cpu()
                    reconstructed_images = (reconstructed_images + 1) / 2  # [-1, 1] 범위를 [0, 1]로 변환

                    attention_map = (attention_out1 - attention_out1.min()) / (attention_out1.max() - attention_out1.min())
                    # 마스킹 이미지와 Attention 맵을 곱하여 강조
                    #attention_map = (attention_map > 0.5).float()
                    attention_applied_images = masked_images * attention_map
                    attention_applied_images = attention_applied_images.cpu()
                    
                    trans_images = (original_images + 1) / 2  # 동일하게 원본 이미지도 변환





                    
                    
                    # MSE 및 PSNR 계산
                    mse = F.mse_loss(reconstructed_images, trans_images, reduction='sum').item()  # 배치의 MSE
                    psnr = calculate_psnr(reconstructed_images, trans_images).sum().item()  # 배치의 PSNR
                    ssim = ssim_metric(reconstructed_images, trans_images)
                    
                    avg_mses = mse/BATCH_SIZE
                    avg_psnrs = psnr/BATCH_SIZE
                    avg_ssims = ssim

                    mses.append(avg_mses)
                    psnrs.append(avg_psnrs)
                    ssims.append(avg_ssims)
                    
                    print(f"Average MSE per 1 image : {avg_mses}")
                    print(f"Average PSNR per 1 image : {avg_psnrs}")
                    print(f"Average SSIM per 1 image: {avg_ssims}")
                    
        
                    # 이미지 시각화 및 저장
                    fig, axs = plt.subplots(4, 1, figsize=(8, 32))
        
                    # 원본 이미지 시각화
                    axs[0].set_title(f"Original Images at Epoch {epoch+1}")
                    original_grid = vutils.make_grid(original_images, padding=2, normalize=True)
                    axs[0].imshow(np.transpose(original_grid, (1, 2, 0)))
                    axs[0].axis("off")
        
                    # 마스킹된 이미지 시각화
                    axs[1].set_title(f"Masked Images at Epoch {epoch+1}")
                    masked_grid = vutils.make_grid(masked_images_cpu, padding=2, normalize=True)
                    axs[1].imshow(np.transpose(masked_grid, (1, 2, 0)))
                    axs[1].axis("off")
        
                    # 복원된 이미지 시각화
                    axs[2].set_title(f"Reconstructed Images at Epoch {epoch+1}")
                    reconstructed_grid = vutils.make_grid(reconstructed_images, padding=2, normalize=True)
                    axs[2].imshow(np.transpose(reconstructed_grid, (1, 2, 0)))
                    axs[2].axis("off")
        
                    # Attention으로 마스킹된 이미지 시각화
                    axs[3].set_title(f"Attention Masked Images at Epoch {epoch+1}")
                    attention_masked_grid = vutils.make_grid(attention_applied_images, padding=2, normalize=True)
                    axs[3].imshow(np.transpose(attention_masked_grid, (1, 2, 0)))
                    axs[3].axis("off")
        
                    # 이미지 파일로 저장
                    save_path = os.path.join(save_dir, f'SAAAE_masked_reconstructed_epoch_{epoch+1}.png')
                    plt.savefig(save_path)
                    plt.close()  # plt.show() 대신 close()를 사용해 메모리 관리
                    break
    return losses,mses, psnrs,ssims, best_model

In [7]:
save_dir = './generated_saaae_linear'
device = 'cuda'
saaae = SAAAE().to(device)
optimizer = optim.Adam(saaae.parameters(), lr=1e-3)
criterion = nn.MSELoss(reduction='sum')  # Mean Squared Error Loss


losses,mses, psnrs,ssims, best_model = train(saaae, train_loader,val_loader, EPOCHS,criterion, save_dir, device)



Epoch 1, Loss: 732034.034375
Average MSE per 1 image : 587.2535400390625
Average PSNR per 1 image : 13.739505767822266
Average SSIM per 1 image: 0.2369793951511383
Epoch 2, Loss: 585906.0890625
Epoch 3, Loss: 470232.028125
Epoch 4, Loss: 405818.66875
Epoch 5, Loss: 375347.0703125
Epoch 6, Loss: 355327.30625
Epoch 7, Loss: 332177.7375
Epoch 8, Loss: 318493.71953125
Epoch 9, Loss: 304190.99453125
Epoch 10, Loss: 294211.59140625
Epoch 11, Loss: 290152.4171875
Epoch 12, Loss: 276897.6125
Epoch 13, Loss: 265032.346875
Epoch 14, Loss: 257037.04609375
Epoch 15, Loss: 243388.27265625
Epoch 16, Loss: 233737.803125
Epoch 17, Loss: 227653.85625
Epoch 18, Loss: 227568.022265625
Epoch 19, Loss: 222949.516796875
Epoch 20, Loss: 216903.47109375
Epoch 21, Loss: 212587.7234375
Epoch 22, Loss: 212879.36953125
Epoch 23, Loss: 207285.926171875
Epoch 24, Loss: 201854.29609375
Epoch 25, Loss: 199816.2859375
Epoch 26, Loss: 198552.151171875
Epoch 27, Loss: 196435.3609375
Epoch 28, Loss: 199386.38828125
Epoch

In [8]:
# 에포크별 평균 손실 계산
avg_losses = [epoch_loss / (BATCH_SIZE * len(train_loader)) for epoch_loss in losses]

# 평균 손실값을 확인
print(avg_losses)


[142.9753973388672, 114.43478302001954, 91.84219299316406, 79.26145874023437, 73.30997467041016, 69.39986450195313, 64.87846435546875, 62.20580459594727, 59.41230361938476, 57.463201446533205, 56.67039398193359, 54.08156494140625, 51.76413024902344, 50.202548065185546, 47.53677200317382, 45.65191467285156, 44.46364379882813, 44.446879348754884, 43.54482749938965, 42.36395919799805, 41.521039733886724, 41.57800186157227, 40.485532455444336, 39.42466720581054, 39.02661834716797, 38.77971702575684, 38.36628143310547, 38.94265396118164, 37.522045669555666, 36.502584609985355, 36.10452056884766, 35.82735336303711, 35.332769851684574, 35.188144760131834, 34.73895713806152, 35.72604949951172, 34.4162508392334, 34.154549560546876, 33.336317520141606, 33.100962753295896, 33.20973724365234, 32.80037994384766, 32.52139488220215, 32.587163009643554, 31.75792518615723, 31.204191513061524, 31.299560623168947, 31.169397964477536, 31.42755645751953, 30.532183227539065, 30.17274574279785, 29.9121641540

In [9]:
def save_metric_to_file(metric, file_name, save_dir, metric_name):
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, file_name)
    
    with open(file_path, 'w') as f:
        f.write(f"{metric_name}:\n")
        for i, value in enumerate(metric, 1):
            f.write(f"{i} {value}\n")
    
    print(f"{metric_name} saved to {file_path}")

In [10]:
# 각각의 메트릭을 별도 파일에 저장
save_metric_to_file(avg_losses, 'saaae_linear_epoch_losses.txt', save_dir, 'Loss')
save_metric_to_file(mses, 'saaae_linear_mean_squared_errors.txt', save_dir, 'MSE')
save_metric_to_file(psnrs, 'saaae_linear_peak_signal_to_noise_ratios.txt', save_dir, 'PSNR')

Loss saved to ./generated_saaae_linear/saaae_linear_epoch_losses.txt
MSE saved to ./generated_saaae_linear/saaae_linear_mean_squared_errors.txt
PSNR saved to ./generated_saaae_linear/saaae_linear_peak_signal_to_noise_ratios.txt


In [13]:
save_metric_to_file(ssims, 'saaae_linear_ssims.txt', save_dir, 'SSIM')

SSIM saved to ./generated_saaae_linear/saaae_linear_ssims.txt


In [11]:
saaae.load_state_dict(best_model)

<All keys matched successfully>

In [12]:
prn

NameError: name 'prn' is not defined

In [None]:
model.eval()
with torch.no_grad():
    for batch_idx, (images, _) in enumerate(val_loader):
        images = images.to(device)
        
        # 이미지에 랜덤 마스킹 적용
        masked_images = random_mask(images)
        
        optimizer.zero_grad()
        
        # VAE에서 복원된 이미지와 잠재 벡터를 얻음
        recon_images, mu, logvar, attention_out1, attention_out3 = model(masked_images)


        # 원본 배치에서 이미지 가져오기
        original_images = images.cpu()
        masked_images_cpu = masked_images.cpu()

        reconstructed_images = recon_images.cpu()
        reconstructed_images = (reconstructed_images + 1) / 2  # [-1, 1] 범위를 [0, 1]로 변환

        atten_images = attention_out1.cpu()
        atten_images = (atten_images + 1) / 2  # [-1, 1] 범위를 [0, 1]로 변환

        trans_images = (original_images + 1) / 2  # 동일하게 원본 이미지도 변환
        
        # MSE 및 PSNR 계산
        mse = F.mse_loss(reconstructed_images, trans_images, reduction='sum').item()  # 배치의 MSE
        psnr = calculate_psnr(reconstructed_images, trans_images).sum().item()  # 배치의 PSNR

        avg_mses = mse/BATCH_SIZE
        avg_psnrs = psnr/BATCH_SIZE

        mses.append(avg_mses)
        psnrs.append(avg_psnrs)
        print(f"Average MSE per 1 image : {avg_mses}")
        print(f"Average PSNR per 1 image : {avg_psnrs}")

        

        # # 이미지 시각화 및 저장
        # fig, axs = plt.subplots(4, 1, figsize=(8, 32))

        # # 원본 이미지 시각화
        # axs[0].set_title(f"Original Images at Epoch {epoch+1}")
        # original_grid = vutils.make_grid(original_images, padding=2, normalize=True)
        # axs[0].imshow(np.transpose(original_grid, (1, 2, 0)))
        # axs[0].axis("off")

        # # 마스킹된 이미지 시각화
        # axs[1].set_title(f"Masked Images at Epoch {epoch+1}")
        # masked_grid = vutils.make_grid(masked_images_cpu, padding=2, normalize=True)
        # axs[1].imshow(np.transpose(masked_grid, (1, 2, 0)))
        # axs[1].axis("off")

        # # 복원된 이미지 시각화
        # axs[2].set_title(f"Reconstructed Images at Epoch {epoch+1}")
        # reconstructed_grid = vutils.make_grid(reconstructed_images, padding=2, normalize=True)
        # axs[2].imshow(np.transpose(reconstructed_grid, (1, 2, 0)))
        # axs[2].axis("off")

        # # Attention으로 마스킹된 이미지 시각화
        # axs[3].set_title(f"Attention Masked Images at Epoch {epoch+1}")
        # attention_masked_grid = vutils.make_grid(atten_images, padding=2, normalize=True)
        # axs[3].imshow(np.transpose(attention_masked_grid, (1, 2, 0)))
        # axs[3].axis("off")

        # # 이미지 파일로 저장
        # save_path = os.path.join(save_dir, f'SAAAE_masked_reconstructed_epoch_{epoch+1}.png')
        # plt.savefig(save_path)
        # plt.close()  # plt.show() 대신 close()를 사용해 메모리 관리
        # break

In [None]:
import torch.nn.functional as F
import math
import time
import torchvision.utils as vutils
all_time = []
# PSNR 계산 함수 정의


    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(data_loader):
            images = images.to(device)
            masked_images = random_mask(images)
            
            # VAE로 복원된 이미지 얻기
            now = time.time()
            recon_images, mu, logvar, attention_out1, attention_out3 = vae(masked_images)

            attention_out1 = einops.rearrange(attention_out1, 'b c (h w) -> b c h w', h=64)

            print("Inference time:", time.time() - now)
            recon_images = (recon_images + 1) / 2  # [-1, 1] 범위를 [0, 1]로 변환
            images = (images + 1) / 2  # 동일하게 원본 이미지도 변환

            # MSE 및 PSNR 계산
            mse = F.mse_loss(recon_images, images, reduction='sum').item()  # 배치의 MSE
            psnr = calculate_psnr(recon_images, images).sum().item()  # 배치의 PSNR
            
            total_mse += mse
            total_psnr += psnr
            total_images += images.size(0)

            # 마스킹된 이미지 및 attention map 시각화 및 저장
            masked_images_cpu = masked_images.cpu()
            attention_map_cpu = attention_out1.cpu()  # attention_map을 [B, 1, H, W]로 가정
            
            for i in range(masked_images_cpu.size(0)):
                fig, axs = plt.subplots(1, 2, figsize=(8, 4))
                
                # 마스킹된 이미지 저장
                axs[0].set_title("Masked Image")
                masked_img = masked_images_cpu[i]
                masked_grid = vutils.make_grid(masked_img, normalize=True, scale_each=True)
                axs[0].imshow(np.transpose(masked_grid.numpy(), (1, 2, 0)))
                axs[0].axis("off")

                # Attention map 저장
                axs[1].set_title("Masked Image")
                attention_img = attention_map_cpu[i]
                attention_grid = vutils.make_grid(attention_img, normalize=True, scale_each=True)
                axs[1].imshow(np.transpose(attention_grid.numpy(), (1, 2, 0)))
                axs[1].axis("off")
                break
                # # 이미지 파일로 저장
                # save_path = os.path.join(save_dir, f"batch_{batch_idx}_img_{i}_masked_attention.png")
                # plt.savefig(save_path)
                plt.close()

    avg_mse = total_mse / total_images
    avg_psnr = total_psnr / total_images
    return avg_mse, avg_psnr

In [None]:
avg_mse, avg_psnr = evaluate_vae(vae, train_loader, device,"./")
print(f"Epoch {epoch+1}, Average MSE: {avg_mse}, Average PSNR: {avg_psnr} dB")