In [88]:
import copy
import torch
import torch.nn as nn
import torch.optim as optim

In [89]:
# Simple model definition
class SimpleModel(nn.Module):
    """Two stacked linear layers (10 -> 5 -> 1) with no activation in between.

    ``layer1`` is the layer the experiment freezes; ``layer2`` is the layer
    that keeps being trained.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(in_features=10, out_features=5)  # freeze target
        self.layer2 = nn.Linear(in_features=5, out_features=1)  # trained layer

    def forward(self, x):
        """Apply ``layer1`` and then ``layer2`` to ``x`` and return the result."""
        hidden = self.layer1(x)
        return self.layer2(hidden)

# Seed torch's global RNG so the randomly initialised weights (and random
# tensors drawn afterwards) are the same on every Restart & Run All.
torch.manual_seed(0)

# Build one model and deep-copy it so both freezing strategies start from
# identical weights and can be tested independently.
model1 = SimpleModel()  # strategy 1: requires_grad=False
model2 = copy.deepcopy(model1)  # strategy 2: learning_rate=0

In [90]:
# Loss function used for both models
criterion = nn.MSELoss()

# Optimizer for model1: both layers are registered with a non-zero learning
# rate. layer1 is NOT frozen through the optimizer; it is frozen below by
# switching off requires_grad on its parameters.
optimizer1 = optim.SGD([
    dict(params=model1.layer1.parameters(), lr=0.1),
    dict(params=model1.layer2.parameters(), lr=0.1),
])

# Optimizer for model2: layer1 is "frozen" by giving it a learning rate of 0,
# while layer2 is trained with lr=0.1.
optimizer2 = optim.SGD([
    dict(params=model2.layer1.parameters(), lr=0.0),
    dict(params=model2.layer2.parameters(), lr=0.1),
])

# 1. requires_grad=False experiment: disable gradient tracking for every
# parameter of model1.layer1.
model1.layer1.requires_grad_(False)

# Random input batch and targets (8 samples, 10 features -> 1 target)
x = torch.randn(8, 10)  # input data
y = torch.randn(8, 1)   # target data

In [91]:
# Forward pass: evaluate both models on the same batch with the same loss.
output1 = model1(x)
loss1 = criterion(output1, y)

output2 = model2(x)
loss2 = criterion(output2, y)

# Clear gradients left over from a previous run of this cell. backward()
# accumulates into .grad, so re-running the cell without this would sum
# gradients and change the update applied by the optimizers.
optimizer1.zero_grad()
optimizer2.zero_grad()

# Backward pass
loss1.backward()
loss2.backward()

In [92]:
# Optimizer step: apply one SGD update per model (model1 first, then model2)
for optimizer in (optimizer1, optimizer2):
    optimizer.step()

# Weight comparison

In [93]:
# Compare layer1 weights: model1 froze it via requires_grad=False,
# model2 via learning_rate=0.
print("\nlayer1.weight values comparison:")
layer1_weights_match = torch.allclose(
    model1.layer1.weight.data,
    model2.layer1.weight.data,
)
print("Are weights equal?", layer1_weights_match)


# Compare layer2 weights: trained with lr=0.1 in both models.
print("\nlayer2.weight values comparison:")
layer2_weights_match = torch.allclose(
    model1.layer2.weight.data,
    model2.layer2.weight.data,
)
print("Are weights equal?", layer2_weights_match)


layer1.weight values comparison:
Are weights equal? True

layer2.weight values comparison:
Are weights equal? True
