In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install --upgrade datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
import random

# 1. EuroSAT 데이터셋 로드
dataset = load_dataset('Honaker/eurosat_dataset')

# 라벨 인덱스를 클래스명으로 매핑
label2class = [
    "Annual crop",
    "Forest",
    "Herbaceous vegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "Permanent crop",
    "Residential",
    "River",
    "Sea/Lake"
]

# 이미지 변환 (CLIP 모델 표준 입력 크기: 224x224)
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)) # CLIP mean/std
])

def collate_fn(examples):
    images = [transform(x["image"]) for x in examples]
    labels = [x["label"] for x in examples]
    images = torch.stack(images)  # [batch, 3, 224, 224]
    labels = torch.tensor(labels) # [batch]
    return images, labels

# 2. 훈련 데이터셋에서 80개 샘플 선택
train_dataset = dataset["train"]
random.seed(42)  # 재현성을 위해 시드 고정
indices = random.sample(range(len(train_dataset)), 80)
subset_dataset = Subset(train_dataset, indices)

# 3. DataLoader 생성
train_dataloader = DataLoader(subset_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# 4. CLIP 모델 로드 및 freeze
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

for param in model.parameters():
    param.requires_grad = False
model.eval()

# CLIP 임베딩 차원 확인 (text & image 동일 차원)
embed_dim = model.config.projection_dim

# 학습 대상 weight vector w 정의 (문장 임베딩 차원과 동일)
w = nn.Parameter(torch.zeros(embed_dim, dtype=torch.float32))
optimizer = optim.Adam([w], lr=1e-3)

criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
w = w.to(device)

epochs = 100
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        texts = [f"A satellite image showing a {label2class[l].lower()}." for l in labels.tolist()]
        text_inputs = processor.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            # 이미지 임베딩 추출
            image_embeds = model.get_image_features(pixel_values=images)
            # 텍스트 임베딩 추출
            text_embeds = model.get_text_features(**text_inputs)

        # w를 텍스트 임베딩에 더함 (텍스트 임베딩: [batch, embed_dim])
        # w: [embed_dim] 이므로 broadcast되어 각 배치 샘플 임베딩에 더해진다.
        sentence_embeds = text_embeds + w

        # 노멀라이즈
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        sentence_embeds = sentence_embeds / sentence_embeds.norm(p=2, dim=-1, keepdim=True)

        logits = torch.matmul(image_embeds, sentence_embeds.t())
        target = torch.arange(logits.size(0), device=device)
        loss = (criterion(logits, target) + criterion(logits.t(), target)) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    w_sum = w.sum().detach().cpu().item()
    print("w_sum scalar:", w_sum)
    print()


Epoch 1/100, Loss: 1.3758
w_sum scalar: 0.0

Epoch 2/100, Loss: 1.3754
w_sum scalar: 0.0

Epoch 3/100, Loss: 1.3746
w_sum scalar: 0.0

Epoch 4/100, Loss: 1.3755
w_sum scalar: 0.0

Epoch 5/100, Loss: 1.3752
w_sum scalar: 0.0

Epoch 6/100, Loss: 1.3767
w_sum scalar: 0.0

Epoch 7/100, Loss: 1.3749
w_sum scalar: 0.0

Epoch 8/100, Loss: 1.3757
w_sum scalar: 0.0

Epoch 9/100, Loss: 1.3753
w_sum scalar: 0.0

Epoch 10/100, Loss: 1.3774
w_sum scalar: 0.0

Epoch 11/100, Loss: 1.3756
w_sum scalar: 0.0

Epoch 12/100, Loss: 1.3762
w_sum scalar: 0.0

Epoch 13/100, Loss: 1.3759
w_sum scalar: 0.0

Epoch 14/100, Loss: 1.3753
w_sum scalar: 0.0

Epoch 15/100, Loss: 1.3749
w_sum scalar: 0.0

Epoch 16/100, Loss: 1.3759
w_sum scalar: 0.0

Epoch 17/100, Loss: 1.3754
w_sum scalar: 0.0

Epoch 18/100, Loss: 1.3750
w_sum scalar: 0.0

Epoch 19/100, Loss: 1.3752
w_sum scalar: 0.0

Epoch 20/100, Loss: 1.3752
w_sum scalar: 0.0

Epoch 21/100, Loss: 1.3758
w_sum scalar: 0.0

Epoch 22/100, Loss: 1.3753
w_sum scalar: 0.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
import random

# EuroSAT 데이터셋 로드
dataset = load_dataset('Honaker/eurosat_dataset')

label2class = [
    "Annual crop",
    "Forest",
    "Herbaceous vegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "Permanent crop",
    "Residential",
    "River",
    "Sea/Lake"
]

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])

def collate_fn(examples):
    images = [transform(x["image"]) for x in examples]
    labels = [x["label"] for x in examples]
    images = torch.stack(images)
    labels = torch.tensor(labels)
    return images, labels

# 일부 샘플만 사용(예: 8개)
train_dataset = dataset["train"]
random.seed(42)
indices = random.sample(range(len(train_dataset)), 8)
subset_dataset = Subset(train_dataset, indices)

train_dataloader = DataLoader(subset_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

for param in model.parameters():
    param.requires_grad = False
model.eval()

embed_dim = model.config.projection_dim
w = nn.Parameter(torch.zeros(embed_dim, dtype=torch.float32))

optimizer = optim.Adam([w], lr=1e-3)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
w = w.to(device)

epochs = 3

for epoch in range(epochs):
    running_loss = 0.0
    last_token_weight_sums = None
    last_texts = None
    last_input_ids = None

    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        texts = [f"A satellite image showing a {label2class[l].lower()}." for l in labels.tolist()]

        text_inputs = processor.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            image_embeds = model.get_image_features(pixel_values=images)
            text_outputs = model.text_model(**text_inputs)
            last_hidden_state = text_outputs.last_hidden_state  # [batch, seq_len, embed_dim]

        # w 더하기
        last_hidden_state = last_hidden_state + w

        # 토큰별 weight sum
        token_weight_sums = last_hidden_state.sum(dim=-1)  # [batch, seq_len]

        # 문장 임베딩: 단순 평균 풀링 (예시)
        sentence_embeds = last_hidden_state.mean(dim=1)

        # 정규화
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        sentence_embeds = sentence_embeds / sentence_embeds.norm(p=2, dim=-1, keepdim=True)

        logits = torch.matmul(image_embeds, sentence_embeds.t())
        target = torch.arange(logits.size(0), device=device)
        loss = (criterion(logits, target) + criterion(logits.t(), target)) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        last_token_weight_sums = token_weight_sums
        last_texts = texts
        last_input_ids = text_inputs["input_ids"]

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    # Epoch 종료 후 마지막 batch의 토큰별 weight sum 출력
    # 토큰 해석: 토크나이저의 decode 이용
    if last_token_weight_sums is not None:
        for i in range(last_token_weight_sums.size(0)):
            print(f"Text {i+1}: {last_texts[i]}")
            token_ids = last_input_ids[i].detach().cpu().tolist()
            tokens = processor.tokenizer.convert_ids_to_tokens(token_ids)
            # 패딩이나 스페셜 토큰 제거를 위해 실제 Attention Mask나 특수 토큰 처리 필요할 수도 있음.
            # 여기서는 단순히 토큰 시퀀스 전체를 출력
            print("Tokens:", tokens)
            print("Token sums:", last_token_weight_sums[i].detach().cpu().numpy())
            print()


Epoch 1/3, Loss: 1.3851
Text 1: A satellite image showing a highway.
Tokens: ['<|startoftext|>', 'a</w>', 'satellite</w>', 'image</w>', 'showing</w>', 'a</w>', 'highway</w>', '.</w>', '<|endoftext|>', '<|endoftext|>']
Token sums: [54.16202  56.88982  60.224625 66.41615  63.714893 63.593544 63.531567
 62.376896 62.487156 62.654423]

Text 2: A satellite image showing a highway.
Tokens: ['<|startoftext|>', 'a</w>', 'satellite</w>', 'image</w>', 'showing</w>', 'a</w>', 'highway</w>', '.</w>', '<|endoftext|>', '<|endoftext|>']
Token sums: [54.16202  56.88982  60.224625 66.41615  63.714893 63.593544 63.531567
 62.376896 62.487156 62.654423]

Text 3: A satellite image showing a annual crop.
Tokens: ['<|startoftext|>', 'a</w>', 'satellite</w>', 'image</w>', 'showing</w>', 'a</w>', 'annual</w>', 'crop</w>', '.</w>', '<|endoftext|>']
Token sums: [54.16202  56.88982  60.224625 66.41615  63.714893 63.593544 62.018616
 62.51574  61.727493 63.16903 ]

Text 4: A satellite image showing a forest.
Toke

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
import random

# EuroSAT 데이터셋 로드
dataset = load_dataset('Honaker/eurosat_dataset')

label2class = [
    "Annual crop",
    "Forest",
    "Herbaceous vegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "Permanent crop",
    "Residential",
    "River",
    "Sea/Lake"
]

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])

def collate_fn(examples):
    images = [transform(x["image"]) for x in examples]
    labels = [x["label"] for x in examples]
    images = torch.stack(images)
    labels = torch.tensor(labels)
    return images, labels

# 일부 샘플만 사용 (예: 8개)
train_dataset = dataset["train"]
random.seed(42)
indices = random.sample(range(len(train_dataset)), 8)
subset_dataset = Subset(train_dataset, indices)

train_dataloader = DataLoader(subset_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

for param in model.parameters():
    param.requires_grad = False
model.eval()

embed_dim = model.config.projection_dim

# 최대 토큰 길이를 정해야 합니다.
# 여기서는 예시로 fixed length(= 77)로 가정. CLIP의 기본 최대 토큰 길이는 일반적으로 77 토큰 정도입니다.
max_seq_len = 77
w = nn.Parameter(torch.zeros(max_seq_len, embed_dim, dtype=torch.float32))

optimizer = optim.Adam([w], lr=1e-3)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
w = w.to(device)

epochs = 50

for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        texts = [f"A satellite image showing a {label2class[l].lower()}." for l in labels.tolist()]

        text_inputs = processor.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            image_embeds = model.get_image_features(pixel_values=images)
            text_outputs = model.text_model(**text_inputs)
            # last_hidden_state: [batch, seq_len, embed_dim]
            orig_hidden_state = text_outputs.last_hidden_state

        # orig_hidden_state.shape: [batch, seq_len, embed_dim]
        # w.shape: [seq_len, embed_dim]
        # 각 토큰별로 동일한 w를 더하기 위해 broadcasting 사용
        # batch 차원: broadcast
        # last_hidden_state = orig_hidden_state + w
        # 여기서 seq_len은 실제 문장 토큰 길이보다 클 수 있으므로 attention_mask 등을 이용해 실제 토큰 길이에 맞게 적용 가능
        # 단순히 앞 부분 토큰만 w 추가한다고 가정
        seq_length = orig_hidden_state.size(1)
        # 입력된 seq_length가 max_seq_len 이하라고 가정
        last_hidden_state = orig_hidden_state + w[:seq_length, :]

        # 평균 풀링으로 문장 임베딩 만들기 (예시)
        sentence_embeds = last_hidden_state.mean(dim=1)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        sentence_embeds = sentence_embeds / sentence_embeds.norm(p=2, dim=-1, keepdim=True)

        logits = torch.matmul(image_embeds, sentence_embeds.t())
        target = torch.arange(logits.size(0), device=device)
        loss = (criterion(logits, target) + criterion(logits.t(), target)) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    # 매 epoch마다 w_sum 계산 및 출력
    # w_sum: 각 토큰별로 embed_dim 방향 합
    w_sum = w.sum(dim=1)  # [seq_len]
    print("w_sum vector:", w_sum.detach().cpu().numpy())
    print()


Epoch 1/50, Loss: 1.3847
w_sum vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]

Epoch 2/50, Loss: 1.3852
w_sum vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]

Epoch 3/50, Loss: 1.3852
w_sum vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]

Epoch 4/50, Loss: 1.3844
w_sum vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
import random

# EuroSAT 데이터셋 로드
dataset = load_dataset('Honaker/eurosat_dataset')

label2class = [
    "Annual crop",
    "Forest",
    "Herbaceous vegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "Permanent crop",
    "Residential",
    "River",
    "Sea/Lake"
]

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])

def collate_fn(examples):
    images = [transform(x["image"]) for x in examples]
    labels = [x["label"] for x in examples]
    images = torch.stack(images)
    labels = torch.tensor(labels)
    return images, labels

# 일부 샘플만 사용 (예: 8개)
train_dataset = dataset["train"]
random.seed(42)
indices = random.sample(range(len(train_dataset)), 8)
subset_dataset = Subset(train_dataset, indices)

train_dataloader = DataLoader(subset_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# 모델 파라미터 freeze
for param in model.parameters():
    param.requires_grad = False
model.eval()

embed_dim = model.config.projection_dim

# w를 [max_seq_len, embed_dim] 형태로 정의
max_seq_len = 77
w = nn.Parameter(torch.zeros(max_seq_len, embed_dim, dtype=torch.float32))

optimizer = optim.Adam([w], lr=1e-3)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
w = w.to(device)

epochs = 3

for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        texts = [f"A satellite image showing a {label2class[l].lower()}." for l in labels.tolist()]

        text_inputs = processor.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        # 그래디언트 추적 활성화 (no_grad 제거)
        # 모델 파라미터는 requires_grad=False이므로 업데이트 안됨. 그러나 연산은 graph에 포함.
        text_outputs = model.text_model(**text_inputs)
        orig_hidden_state = text_outputs.last_hidden_state  # [batch, seq_len, embed_dim]

        image_embeds = model.get_image_features(pixel_values=images)

        seq_length = orig_hidden_state.size(1)
        # w[:seq_length, :]를 각 문장에 동일하게 적용
        # shape 방송: [batch, seq_len, embed_dim]
        last_hidden_state = orig_hidden_state + w[:seq_length, :]

        # 문장 임베딩 계산 (단순 평균)
        sentence_embeds = last_hidden_state.mean(dim=1)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        sentence_embeds = sentence_embeds / sentence_embeds.norm(p=2, dim=-1, keepdim=True)

        logits = torch.matmul(image_embeds, sentence_embeds.t())
        target = torch.arange(logits.size(0), device=device)
        loss = (criterion(logits, target) + criterion(logits.t(), target)) / 2

        optimizer.zero_grad()
        loss.backward()

        # w.grad 확인 (선택적)
        print("w.grad:", w.grad)

        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    # 매 epoch마다 w_sum 벡터 출력
    w_sum = w.sum(dim=1)  # [seq_len] 형태
    print("w_sum vector:", w_sum.detach().cpu().numpy())
    print()


  print("w.grad:", w.grad)


w.grad: None
w.grad: None
Epoch 1/3, Loss: 1.3851
w_sum vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]

w.grad: None
w.grad: None
Epoch 2/3, Loss: 1.3847
w_sum vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]

w.grad: None
w.grad: None
Epoch 3/3, Loss: 1.3848
w_sum vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
import random

# EuroSAT 데이터셋 로드
dataset = load_dataset('Honaker/eurosat_dataset')

label2class = [
    "Annual crop",
    "Forest",
    "Herbaceous vegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "Permanent crop",
    "Residential",
    "River",
    "Sea/Lake"
]

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711))
])

def collate_fn(examples):
    images = [transform(x["image"]) for x in examples]
    labels = [x["label"] for x in examples]
    images = torch.stack(images)
    labels = torch.tensor(labels)
    return images, labels

# 일부 샘플만 사용 (예: 8개)
train_dataset = dataset["train"]
random.seed(42)
indices = random.sample(range(len(train_dataset)), 8)
subset_dataset = Subset(train_dataset, indices)

train_dataloader = DataLoader(subset_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

for param in model.parameters():
    param.requires_grad = False
model.eval()

embed_dim = model.config.projection_dim

# 최대 토큰 길이
max_seq_len = 77
w = nn.Parameter(torch.zeros(max_seq_len, embed_dim, dtype=torch.float32))

optimizer = optim.Adam([w], lr=1e-3)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
w = w.to(device)

epochs = 50

for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        texts = [f"A satellite image showing a {label2class[l].lower()}." for l in labels.tolist()]
        text_inputs = processor.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            image_embeds = model.get_image_features(pixel_values=images)

        # 토큰 임베딩 및 포지션 임베딩 직접 계산
        input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask
        input_shape = input_ids.size()  # [batch_size, seq_length]
        batch_size, seq_length = input_shape

        with torch.no_grad():
            inputs_embeds = model.text_model.embeddings.token_embedding(input_ids)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
            position_embeddings = model.text_model.embeddings.position_embedding(position_ids)
            hidden_states = inputs_embeds + position_embeddings
            # dropout 제거됨

        # w 추가
        hidden_states = hidden_states + w[:seq_length, :]

        # 인코더 통과
        encoder_outputs = model.text_model.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        # final_layer_norm 적용
        last_hidden_state = model.text_model.final_layer_norm(last_hidden_state)

        # 평균 풀링으로 문장 임베딩
        sentence_embeds = last_hidden_state.mean(dim=1)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        sentence_embeds = sentence_embeds / sentence_embeds.norm(p=2, dim=-1, keepdim=True)

        logits = torch.matmul(image_embeds, sentence_embeds.t())
        target = torch.arange(logits.size(0), device=device)
        loss = (criterion(logits, target) + criterion(logits.t(), target)) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    w_sum = w.sum(dim=1).detach().cpu().numpy()
    print("w_sum vector:", w_sum)
    print()


RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: long int and  query.dtype: float instead.