In [5]:
import os

# Configure CUDA device selection through environment variables:
# enumerate GPUs by PCI bus ID and expose only device 5 to this process.
cuda_env = {
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '5',
}
os.environ.update(cuda_env)

# Echo each variable back to confirm it was applied.
for var_name in cuda_env:
    print(f"{var_name}:", os.environ[var_name])


CUDA_DEVICE_ORDER: PCI_BUS_ID
CUDA_VISIBLE_DEVICES: 5


In [9]:
import os
import torch
import torch.nn as nn
import torch.optim as optim

# CUDA environment variables. The CUDA runtime reads these when it
# initializes, so assigning them after CUDA has already been initialized in
# this kernel (e.g. by an earlier cell) silently has no effect.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
if torch.cuda.is_initialized():
    import warnings
    warnings.warn(
        "CUDA is already initialized in this kernel; the new "
        "CUDA_DEVICE_ORDER/CUDA_VISIBLE_DEVICES values will not take effect "
        "until the kernel is restarted."
    )

# Simple PyTorch model definition
class SimpleNN(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    The defaults reproduce the original fixed architecture (784 -> 512 -> 10).
    Layer names (``fc1``, ``fc2``) and their creation order are unchanged, so
    existing callers, saved ``state_dict`` keys and seeded initialization all
    behave as before.

    Args:
        in_features: size of each input vector (last dimension of ``x``).
        hidden_features: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features: int = 784, hidden_features: int = 512,
                 num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits with last dimension ``num_classes``.

        ``x`` must have ``in_features`` as its last dimension. No softmax is
        applied; the raw logits are what ``nn.CrossEntropyLoss`` expects.
        """
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Model initialization.
# Fixed seed so weight init, the fake data and the printed losses are
# reproducible across Restart & Run All.
torch.manual_seed(0)

NUM_SAMPLES = 60000
NUM_FEATURES = 784   # must match SimpleNN's default input size
NUM_CLASSES = 10     # must match SimpleNN's default number of outputs
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Fake data: Gaussian features and uniformly random labels, allocated
# directly on `device` instead of being built on the CPU and copied over.
x_train = torch.randn(NUM_SAMPLES, NUM_FEATURES, device=device)
y_train = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,), device=device)

# Training. Each "epoch" is a single full-batch gradient step over all
# NUM_SAMPLES examples (no mini-batching). Because the labels are random, a
# falling loss only shows the model fitting noise; this is a smoke test.
model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Inference: switch to eval mode (a no-op for this Linear/ReLU model, but
# required once dropout/batch-norm layers are added) and disable autograd.
model.eval()
with torch.no_grad():
    predictions = model(x_train[:10])   # raw logits, shape (10, NUM_CLASSES)
    print(predictions)
    print(predictions.argmax(dim=1))    # predicted class index per sample


Epoch 1, Loss: 2.328396797180176
Epoch 2, Loss: 2.2985475063323975
Epoch 3, Loss: 2.27734375
Epoch 4, Loss: 2.259937047958374
Epoch 5, Loss: 2.2429463863372803
Epoch 6, Loss: 2.2260639667510986
Epoch 7, Loss: 2.2099380493164062
Epoch 8, Loss: 2.194953441619873
Epoch 9, Loss: 2.1810970306396484
Epoch 10, Loss: 2.1680431365966797
tensor([[ 0.1423, -0.2815,  0.0374, -0.4290,  0.4429,  0.0083, -0.3355, -0.0238,
          0.0163, -0.4185],
        [-0.0058, -0.1347, -0.5305,  0.2243,  0.2411, -0.0965, -0.1810, -0.1921,
         -0.2580,  0.0395],
        [ 0.0966, -0.1102, -0.2524,  0.3128,  0.0270,  0.5130, -0.2325,  0.4672,
         -0.0454, -0.2259],
        [-0.0533, -0.0982,  0.2578, -0.1277, -0.1489, -0.5505, -0.5916, -0.1866,
          0.5024,  0.0567],
        [-0.2724,  0.0835,  0.0226, -0.0340,  0.0372, -0.2027, -0.1497, -0.2033,
         -0.3692, -0.1221],
        [-0.2431, -0.3960, -0.6209, -0.0595, -0.8505, -0.1848,  0.0186, -0.0601,
          0.6434,  0.3206],
        [ 0.1074