In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import datetime
import os, sys
import torch.nn.functional as F
from PIL import Image
#-----------------------------------------------------------------------------------------
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms

# Choose the device for training: CUDA when available, otherwise Apple MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:

# Download the CIFAR-10 training split. ToTensor yields floats in [0, 1];
# Normalize with mean=std=0.5 per RGB channel then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
cifar10_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 64 samples; the trailing partial batch is dropped.
cifar_dataloader_loader = DataLoader(
    dataset=cifar10_dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

# Human-readable class names, indexed by integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Pull one shuffled batch from the loader and show its first image with its class name.
data_iter = iter(cifar_dataloader_loader)
images, labels = next(data_iter)

image = images[0]
label = labels[0].item()

# The transform normalized pixels to [-1, 1]; invert Normalize(mean=0.5, std=0.5)
# so matplotlib receives RGB floats in [0, 1] instead of clipping negatives.
image_display = (image * 0.5 + 0.5).clamp(0, 1)
plt.imshow(image_display.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
plt.title(classes[label])
plt.show()


In [None]:

# Inspect a single batch to confirm the image tensor layout and label dtype.
first_batch = next(iter(cifar_dataloader_loader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")



In [None]:

class critics_brain(nn.Module):
    """Classifier head mapping a 64x8x8 feature map to class logits.

    Args:
        num_classes: number of output logits (default 10).

    Input: any tensor whose per-sample size is 64*8*8 (e.g. (N, 64, 8, 8)).
    Output: raw logits of shape (N, num_classes) -- no softmax/sigmoid is
    applied, which is what ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, num_classes=10):
        super(critics_brain, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(8*8*64, 512),
            # Without a non-linearity the two Linear layers collapse into a
            # single affine map and the 512-unit hidden layer is pointless.
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        # flatten() (unlike view()) also accepts non-contiguous inputs.
        y = torch.flatten(x, 1)  # (N, 64*8*8)
        return self.fc(y)


class critics_right_eye(nn.Module):
    """Convolutional feature extractor for 32x32 RGB images.

    Two stages of 3x3 conv (padding 1) + ReLU + 2x2 max-pool:
    (N, 3, 32, 32) -> (N, 32, 16, 16) -> (N, 64, 8, 8).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: (N, 3, 32, 32) -> (N, 32, 16, 16)
        stage_one = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        # Stage 2: (N, 32, 16, 16) -> (N, 64, 8, 8)
        stage_two = [
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        self.conv = nn.Sequential(*stage_one, *stage_two)

    def forward(self, x):
        return self.conv(x)  # (N, 64, 8, 8)

class critics_left_eye(nn.Module):
    """Convolutional feature extractor for 16x16 RGB inputs.

    Two 3x3 conv (padding 1) + ReLU layers, then a single 2x2 max-pool:
    (N, 3, 16, 16) -> (N, 32, 16, 16) -> (N, 64, 16, 16) -> (N, 64, 8, 8).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),   # -> (N, 32, 16, 16)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # -> (N, 64, 16, 16)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                                     # -> (N, 64, 8, 8)
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)  # (N, 64, 8, 8)


class critic(nn.Module):
    """Single-module CIFAR-10 classifier: conv feature extractor + MLP head.

    Args:
        num_classes: number of output logits (default 10).

    Input (N, 3, 32, 32) -> features (N, 64, 8, 8) -> raw logits (N, num_classes).
    """

    def __init__(self, num_classes=10):
        super(critic, self).__init__()
        self.conv = nn.Sequential(
            # (N, 3, 32, 32)
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # (N, 32, 16, 16)
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # (N, 64, 8, 8)
        )
        self.fc = nn.Sequential(
            nn.Linear(8*8*64, 512),
            # Non-linearity so the hidden layer is not a redundant affine map.
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        y = self.conv(x)  # (N, 64, 8, 8)
        y = torch.flatten(y, 1)  # (N, 64*8*8); also safe for non-contiguous tensors
        return self.fc(y)


In [None]:
# Smoke-test the single-module `critic` on the first image of the sample batch.
model = critic().to(device)
model.eval()  # disable dropout so this demo forward pass is deterministic
print(images[0].dtype)  # element dtype of the image tensor
with torch.no_grad():
    # The input must be on the same device as the model's weights.
    demo_logits = model(images[0].unsqueeze(0).to(device))  # (1, num_classes)
demo_logits



In [None]:
# Split model: `Righteye` extracts conv features and `Brain` classifies them.
Brain = critics_brain().to(device)
Righteye = critics_right_eye().to(device)

# Cross-entropy on raw logits, with an independent Adam optimizer per sub-module.
criterion = nn.CrossEntropyLoss()
optim_Brain, optim_Righteye = (
    torch.optim.Adam(module.parameters(), lr=0.001)
    for module in (Brain, Righteye)
)

# Number of passes over the data, and the global step counter used by the training loop.
max_epoch, step = 5, 0


In [None]:
# Jointly train `Righteye` (features) and `Brain` (classifier).
# Every `eval_every` steps, report mean loss and accuracy over `eval_loader`.
# NOTE: no held-out split is loaded in this notebook, so evaluation runs on the
# training loader. A full pass every 10 steps dominates run time; raise
# `eval_every` or set `eval_max_batches` to make it cheaper.
eval_every = 10
eval_max_batches = None  # None = full pass; e.g. 20 for a quick estimate
eval_loader = cifar_dataloader_loader


def evaluate(loader, max_batches=None):
    """Return (mean per-sample loss, accuracy in %) of Brain(Righteye(x)) over `loader`.

    Both modules are switched to eval mode (dropout off) and restored to train
    mode before returning. Only samples actually scored are counted, so the
    result is correct even when the loader drops its last partial batch.
    Returns (nan, nan) if no samples were seen.
    """
    Brain.eval()
    Righteye.eval()
    total_loss, correct, seen = 0.0, 0, 0
    try:
        with torch.no_grad():
            for batch_idx, (eval_images, eval_labels) in enumerate(loader):
                if max_batches is not None and batch_idx >= max_batches:
                    break
                xb, yb = eval_images.to(device), eval_labels.to(device)
                logits = Brain(Righteye(xb))  # (N, 10)
                # criterion averages over the batch; re-weight to get a per-sample mean overall.
                total_loss += criterion(logits, yb).item() * yb.size(0)
                correct += (logits.argmax(dim=-1) == yb).sum().item()
                seen += yb.size(0)
    finally:
        Brain.train()
        Righteye.train()
    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, correct / seen * 100


for epoch in range(max_epoch):
    for images, labels in cifar_dataloader_loader:
        x, y = images.to(device), labels.to(device)  # (N, 3, 32, 32), (N,)
        y_hat = Brain(Righteye(x))  # (N, 10)
        loss = criterion(y_hat, y)

        optim_Brain.zero_grad()
        optim_Righteye.zero_grad()
        loss.backward()
        optim_Brain.step()
        optim_Righteye.step()

        if step % eval_every == 0:
            eval_loss, eval_acc = evaluate(eval_loader, eval_max_batches)
            print('-'*20, 'Eval (train set)', '-'*20)
            print('Step: {}, Loss: {}, Accuracy: {} %'.format(step, eval_loss, eval_acc))
            print('-'*20)

        step += 1
