In [9]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import CrossEntropyLoss
import time

# 데이터 로드 및 전처리
data_A = pd.read_csv("filtered_patients.csv")
data_B = pd.read_csv("filtered_conditions.csv")
data_merged = pd.merge(data_A, data_B, on="ID", how="inner")

selected_columns = ['BIRTHDATE', 'GENDER', 'RACE', 'ETHNICITY', 'DESCRIPTION']
data_selected = data_merged[selected_columns]

# COVID-19 감염 여부 판단
data_selected['INFECTED'] = data_selected['DESCRIPTION'].apply(lambda x: 1 if 'COVID-19' in str(x) else 0)

# 학습 데이터와 테스트 데이터로 나누기
X = data_selected[['BIRTHDATE', 'GENDER', 'RACE', 'ETHNICITY']]
y = data_selected['INFECTED']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# BERT 토크나이저 불러오기
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_data(data):
    texts = [' '.join(map(str, [d])) for d in data]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return inputs

# 데이터셋 토큰화
train_inputs = tokenize_data(X_train)
test_inputs = tokenize_data(X_test)

# 레이블 데이터 변환
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(y_train)
y_test = label_encoder.transform(y_test)

# Tensor 생성
train_labels = torch.tensor(y_train, dtype=torch.long)
test_labels = torch.tensor(y_test, dtype=torch.long)

# 데이터셋 크기 확인
if len(train_inputs['input_ids']) != len(train_inputs['attention_mask']) or len(train_inputs['input_ids']) != len(train_labels):
    raise ValueError("Size mismatch between tensors")

# DataLoader 생성
train_dataset = TensorDataset(train_inputs['input_ids'], train_inputs['attention_mask'], train_labels)
test_dataset = TensorDataset(test_inputs['input_ids'], test_inputs['attention_mask'], test_labels)

# 적절한 batch size 선택
train_batch_size = 64
test_batch_size = 64

# DataLoader 생성
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=True
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            labels=labels,
            output_hidden_states=output_hidden_states
        )
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-1]
        loss = outputs.loss
        return logits, loss, hidden_states

model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = CrossEntropyLoss()

# BERT 모델 학습
start_time = time.time()
model.train()
for epoch in range(3):
    total_loss = 0
    for batch_idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        logits, loss, _ = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item()}")  # 각 배치의 손실값 출력
    average_loss = total_loss / (batch_idx + 1)  # 배치의 수로 나눠줌
    print(f"Epoch {epoch+1} Average Loss: {average_loss}")

end_time = time.time()
print("Training complete.")
print(f"Training time: {end_time - start_time} seconds")

# BERT 모델 평가
model.eval()
correct_predictions = 0
total_predictions = 0
with torch.no_grad():
    for batch in test_dataloader:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        logits, _, _ = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(logits, dim=1)
        correct_predictions += torch.sum(predictions == labels).item()
        total_predictions += len(labels)
accuracy = correct_predictions / total_predictions
print(f"BERT Model Accuracy: {accuracy}")


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_selected['INFECTED'] = data_selected['DESCRIPTION'].apply(lambda x: 1 if 'COVID-19' in str(x) else 0)


ValueError: Size mismatch between tensors

In [10]:
import os
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import CrossEntropyLoss
from sklearn.ensemble import RandomForestClassifier
import time


# Step 1: 데이터 로드 및 전처리
data_A = pd.read_csv("filtered_patients.csv")
data_B = pd.read_csv("filtered_conditions.csv")
data_merged = pd.merge(data_A, data_B, on="ID", how="inner")

# print(data_merged)

selected_columns = ['ID', 'BIRTHDATE', 'GENDER', 'ADDRESS', 'DESCRIPTION']
def classify_covid19(description):
    if "COVID-19" in description:
        return 1
    elif "Suspected COVID-19" in description:
        return 1
    else:
        return 0

data_processed = data_merged[selected_columns].copy()
data_processed['GENDER'] = LabelEncoder().fit_transform(data_processed['GENDER'])
data_processed['COVID-19'] = data_processed['DESCRIPTION'].apply(classify_covid19)

# Step 2: 데이터 분할
X_train, X_test, y_train, y_test = train_test_split(data_processed['DESCRIPTION'], data_processed['COVID-19'], test_size=0.2, random_state=42)

# 테스트 데이터셋의 라벨 분포 확인
# print(y_test.value_counts())

# Step 3: BERT 모델 학습 및 평가
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_data(data):
    texts = [' '.join(map(str, [d])) for d in data]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return inputs

# 데이터셋 토큰화
train_inputs = tokenize_data(X_train)
test_inputs = tokenize_data(X_test)

# 레이블 데이터 변환
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(y_train)
y_test = label_encoder.transform(y_test)

# Tensor 생성
train_labels = torch.tensor(y_train, dtype=torch.long)
test_labels = torch.tensor(y_test, dtype=torch.long)

# DataLoader 생성
train_batch_size = 64
test_batch_size = 64

train_dataset = TensorDataset(train_inputs['input_ids'], train_inputs['attention_mask'], train_labels)
test_dataset = TensorDataset(test_inputs['input_ids'], test_inputs['attention_mask'], test_labels)
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=True
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            labels=labels,
            output_hidden_states=output_hidden_states
        )
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-1]
        loss = outputs.loss
        return logits, loss, hidden_states

# 모델 생성 또는 불러오기
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'bert_model.pth'
save_hidden_states_path = 'hidden_states.npy'  # 추가된 부분: hidden states를 저장할 경로

if os.path.exists(model_path):
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.load_state_dict(torch.load(model_path))
    model.to(device)  # 모델을 GPU 또는 CPU로 이동
    print("Model loaded.")
else:
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(torch.cuda.is_available())
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    criterion = CrossEntropyLoss()

    # BERT 모델 학습
    start_time = time.time()
    model.train()
    for epoch in range(3):
        total_loss = 0
        for batch_idx, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            input_ids, attention_mask, labels = batch
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            logits, loss, _ = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item()}")  # 각 배치의 손실값 출력
        average_loss = total_loss / (batch_idx + 1)  # 배치의 수로 나눠줌
        print(f"Epoch {epoch+1} Average Loss: {average_loss}")

    end_time = time.time()
    print("Training complete.")
    print(f"Training time: {end_time - start_time} seconds")

    # 모델 저장
    torch.save(model.state_dict(), model_path)
    
     # 추가된 부분: hidden states 저장
    np.save(save_hidden_states_path, train_features)  # 훈련 데이터셋의 hidden states를 저장

    # BERT 모델 평가
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids, attention_mask, labels = batch
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            logits, _, _ = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(logits, dim=1)
            correct_predictions += torch.sum(predictions == labels).item()
            total_predictions += len(labels)
    accuracy = correct_predictions / total_predictions
    print(f"BERT Model Accuracy: {accuracy}")



# Step 4: BERT 모델의 hidden states를 활용하여 새로운 모델 학습 및 평가
train_features, train_labels = [], []
test_features, test_labels = [], []

def extract_features(data_loader, model):
    model.eval()
    features = []
    labels = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids, attention_mask, label = batch
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            _, _, hidden_states = model(input_ids=input_ids, attention_mask=attention_mask)
            features.append(hidden_states.cpu().numpy())
            labels.append(label.cpu().numpy())
    features = np.concatenate(features)
    labels = np.concatenate(labels)
    return features, labels

# BERT 모델의 hidden states를 활용하여 특징 추출
train_features, train_labels = extract_features(train_dataloader, model)
test_features, test_labels = extract_features(test_dataloader, model)

# Random Forest Classifier에 입력으로 사용할 수 있도록 특성 배열을 2D로 변환
train_features_2d = train_features.reshape(train_features.shape[0], -1)
test_features_2d = test_features.reshape(test_features.shape[0], -1)

# 추가된 부분: hidden states를 불러와서 새 모델 학습에 사용
if os.path.exists(save_hidden_states_path):
    train_features = np.load(save_hidden_states_path)
    print("Hidden states loaded.")
    # 새로운 모델(예: RandomForestClassifier) 학습
    rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_model.fit(train_features_2d, train_labels)
else:
    print("Hidden states file not found. Run the BERT model to generate hidden states.")
    
# 새로운 모델 평가
accuracy = rf_model.score(test_features_2d, test_labels)
print(f"New Model Accuracy: {accuracy}")

# BERT 모델의 hidden states를 활용하여 복구한 데이터 출력
def recover_data_from_hidden_states(hidden_states, tokenizer):
    recovered_data = []
    for batch_hidden_states in hidden_states:
        for example_hidden_states in batch_hidden_states:
            # 각 토큰의 hidden state 중 가장 높은 값을 선택하여 복원
            max_hidden_state = np.max(example_hidden_states[1:], axis=0)  # 첫 번째 토큰([CLS])은 제외
            token_id = np.argmax(max_hidden_state)
            token_ids_list = [token_id]  # 리스트에 담아 반복 가능한 객체로 만듦
            recovered_text = tokenizer.decode(token_ids_list)
            recovered_data.append(recovered_text)
    return recovered_data

# BERT 모델의 hidden states를 활용하여 복구한 데이터 출력
print("Recovered Data from Hidden States:")
train_recovered_data = recover_data_from_hidden_states(train_features, tokenizer)
test_recovered_data = recover_data_from_hidden_states(test_features, tokenizer)

print("Train Recovered Data:")
print(train_recovered_data[:1000])  # 처음 100개의 복구된 데이터 출력
print("Test Recovered Data:")
print(test_recovered_data[:1000])  # 처음 100개의 복구된 데이터 출력

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded.
Hidden states loaded.
New Model Accuracy: 1.0
Recovered Data from Hidden States:
Train Recovered Data:
['[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 

In [12]:
import os
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.nn import CrossEntropyLoss
import torch.optim as optim
import torch.nn as nn
from sklearn.ensemble import RandomForestClassifier
import time


# Step 1: 데이터 로드 및 전처리
data_A = pd.read_csv("filtered_patients.csv")
data_B = pd.read_csv("filtered_conditions.csv")
data_merged = pd.merge(data_A, data_B, on="ID", how="inner")

# print(data_merged)

selected_columns = ['ID', 'BIRTHDATE', 'GENDER', 'ADDRESS', 'DESCRIPTION']
def classify_covid19(description):
    if "COVID-19" in description:
        return 1
    elif "Suspected COVID-19" in description:
        return 1
    else:
        return 0

data_processed = data_merged[selected_columns].copy()
data_processed['GENDER'] = LabelEncoder().fit_transform(data_processed['GENDER'])
data_processed['COVID-19'] = data_processed['DESCRIPTION'].apply(classify_covid19)

# Step 2: 데이터 분할
X_train, X_test, y_train, y_test = train_test_split(data_processed['DESCRIPTION'], data_processed['COVID-19'], test_size=0.2, random_state=42)

print("X_train\n", X_train)
print("y_train\n", y_train)

# 테스트 데이터셋의 라벨 분포 확인
# print(y_test.value_counts())

# Step 3: BERT 모델 학습 및 평가
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_data(data):
    texts = [' '.join(map(str, [d])) for d in data]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return inputs

# 데이터셋 토큰화
train_inputs = tokenize_data(X_train)
test_inputs = tokenize_data(X_test)

# 레이블 데이터 변환
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(y_train)
y_test = label_encoder.transform(y_test)

# Tensor 생성
train_labels = torch.tensor(y_train, dtype=torch.long)
test_labels = torch.tensor(y_test, dtype=torch.long)

# DataLoader 생성
train_batch_size = 64
test_batch_size = 64

train_dataset = TensorDataset(train_inputs['input_ids'], train_inputs['attention_mask'], train_labels)
test_dataset = TensorDataset(test_inputs['input_ids'], test_inputs['attention_mask'], test_labels)
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=True
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            labels=labels,
            output_hidden_states=output_hidden_states
        )
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-1]
        loss = outputs.loss
        return logits, loss, hidden_states

# 모델 생성 또는 불러오기
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'bert_model.pth'
save_hidden_states_path = 'hidden_states.npy'  # 추가된 부분: hidden states를 저장할 경로

if os.path.exists(model_path):
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.load_state_dict(torch.load(model_path))
    model.to(device)  # 모델을 GPU 또는 CPU로 이동
    print("Model loaded.")
else:
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(torch.cuda.is_available())
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    criterion = CrossEntropyLoss()

    # BERT 모델 학습
    start_time = time.time()
    model.train()
    for epoch in range(3):
        total_loss = 0
        for batch_idx, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            input_ids, attention_mask, labels = batch
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            logits, loss, _ = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item()}")  # 각 배치의 손실값 출력
        average_loss = total_loss / (batch_idx + 1)  # 배치의 수로 나눠줌
        print(f"Epoch {epoch+1} Average Loss: {average_loss}")

    end_time = time.time()
    print("Training complete.")
    print(f"Training time: {end_time - start_time} seconds")

    # 모델 저장
    torch.save(model.state_dict(), model_path)
    
     # 추가된 부분: hidden states 저장
    np.save(save_hidden_states_path, train_features)  # 훈련 데이터셋의 hidden states를 저장

    # BERT 모델 평가
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids, attention_mask, labels = batch
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            logits, _, _ = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(logits, dim=1)
            correct_predictions += torch.sum(predictions == labels).item()
            total_predictions += len(labels)
    accuracy = correct_predictions / total_predictions
    print(f"BERT Model Accuracy: {accuracy}")



# 새로운 모델 클래스 정의
class CustomRecoveryModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CustomRecoveryModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 데이터셋 클래스 정의
class CustomDataset(Dataset):
    def __init__(self, features, labels):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
    
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# 데이터 준비
train_dataset = CustomDataset(train_features_2d, train_labels)
test_dataset = CustomDataset(test_features_2d, test_labels)

# 데이터 로더 설정
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 모델, 손실 함수, 옵티마이저 정의
model = CustomRecoveryModel(input_size=train_features_2d.shape[1], hidden_size=64, output_size=1)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 모델 학습
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels.view(-1, 1))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

# 테스트 데이터로 모델 평가
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        predicted = torch.round(torch.sigmoid(outputs))
        total += labels.size(0)
        correct += (predicted == labels.view(-1, 1)).sum().item()

accuracy = correct / total
print(f"Accuracy: {accuracy}")

# 복원된 값을 출력하는 코드 추가
recovered_data = []

model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        recovered_data_batch = torch.round(torch.sigmoid(outputs))
        recovered_data.extend(recovered_data_batch.cpu().numpy().flatten().tolist())

print("Recovered Data:")
print(recovered_data[:1000])  # 처음 10개의 복원된 데이터 출력


X_train
 54696                 Loss of taste (finding)
41913                       First degree burn
74118                 Loss of taste (finding)
57549          Respiratory distress (finding)
89528                      Suspected COVID-19
                         ...                 
76820                         Chill (finding)
110268    Acute pulmonary embolism (disorder)
103694                      Fatigue (finding)
860                        Suspected COVID-19
15795                                COVID-19
Name: DESCRIPTION, Length: 91635, dtype: object
y_train
 54696     0
41913     0
74118     0
57549     0
89528     1
         ..
76820     0
110268    0
103694    0
860       1
15795     1
Name: COVID-19, Length: 91635, dtype: int64


Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


True
Epoch 1, Batch 1/1432, Loss: 0.9207198619842529
Epoch 1, Batch 2/1432, Loss: 0.8558141589164734
Epoch 1, Batch 3/1432, Loss: 0.7800703048706055
Epoch 1, Batch 4/1432, Loss: 0.7023741602897644
Epoch 1, Batch 5/1432, Loss: 0.6612610816955566
Epoch 1, Batch 6/1432, Loss: 0.5973271131515503
Epoch 1, Batch 7/1432, Loss: 0.575785219669342
Epoch 1, Batch 8/1432, Loss: 0.588496744632721
Epoch 1, Batch 9/1432, Loss: 0.5424282550811768
Epoch 1, Batch 10/1432, Loss: 0.5458652973175049
Epoch 1, Batch 11/1432, Loss: 0.5472856163978577
Epoch 1, Batch 12/1432, Loss: 0.513495922088623
Epoch 1, Batch 13/1432, Loss: 0.5053238272666931
Epoch 1, Batch 14/1432, Loss: 0.4773605465888977
Epoch 1, Batch 15/1432, Loss: 0.47765907645225525
Epoch 1, Batch 16/1432, Loss: 0.4523170590400696
Epoch 1, Batch 17/1432, Loss: 0.44228464365005493
Epoch 1, Batch 18/1432, Loss: 0.4149090647697449
Epoch 1, Batch 19/1432, Loss: 0.3964105248451233
Epoch 1, Batch 20/1432, Loss: 0.3746567666530609
Epoch 1, Batch 21/1432, L

Epoch 1, Batch 163/1432, Loss: 0.0036072521470487118
Epoch 1, Batch 164/1432, Loss: 0.0033789852168411016
Epoch 1, Batch 165/1432, Loss: 0.0033729197457432747
Epoch 1, Batch 166/1432, Loss: 0.0032653675880283117
Epoch 1, Batch 167/1432, Loss: 0.0031512302812188864
Epoch 1, Batch 168/1432, Loss: 0.003301331540569663
Epoch 1, Batch 169/1432, Loss: 0.003323799464851618
Epoch 1, Batch 170/1432, Loss: 0.003357558511197567
Epoch 1, Batch 171/1432, Loss: 0.0036757327616214752
Epoch 1, Batch 172/1432, Loss: 0.003202756866812706
Epoch 1, Batch 173/1432, Loss: 0.002844445873051882
Epoch 1, Batch 174/1432, Loss: 0.00297245429828763
Epoch 1, Batch 175/1432, Loss: 0.0029794189613312483
Epoch 1, Batch 176/1432, Loss: 0.003118538996204734
Epoch 1, Batch 177/1432, Loss: 0.0028570308350026608
Epoch 1, Batch 178/1432, Loss: 0.0031137133482843637
Epoch 1, Batch 179/1432, Loss: 0.003085720119997859
Epoch 1, Batch 180/1432, Loss: 0.00271443254314363
Epoch 1, Batch 181/1432, Loss: 0.0027741033118218184
Epoc

Epoch 1, Batch 319/1432, Loss: 0.0012561645125970244
Epoch 1, Batch 320/1432, Loss: 0.0013107002014294267
Epoch 1, Batch 321/1432, Loss: 0.0012233529705554247
Epoch 1, Batch 322/1432, Loss: 0.0012690563453361392
Epoch 1, Batch 323/1432, Loss: 0.0012984576169401407
Epoch 1, Batch 324/1432, Loss: 0.00123692792840302
Epoch 1, Batch 325/1432, Loss: 0.001317961374297738
Epoch 1, Batch 326/1432, Loss: 0.0012753861956298351
Epoch 1, Batch 327/1432, Loss: 0.0012396989623084664
Epoch 1, Batch 328/1432, Loss: 0.0011902584228664637
Epoch 1, Batch 329/1432, Loss: 0.0012286477722227573
Epoch 1, Batch 330/1432, Loss: 0.0010906162206083536
Epoch 1, Batch 331/1432, Loss: 0.0011960135307163
Epoch 1, Batch 332/1432, Loss: 0.001207471708767116
Epoch 1, Batch 333/1432, Loss: 0.0011885161511600018
Epoch 1, Batch 334/1432, Loss: 0.0012086504139006138
Epoch 1, Batch 335/1432, Loss: 0.0011938484385609627
Epoch 1, Batch 336/1432, Loss: 0.0012079435400664806
Epoch 1, Batch 337/1432, Loss: 0.0011433230247348547


Epoch 1, Batch 475/1432, Loss: 0.0007039513438940048
Epoch 1, Batch 476/1432, Loss: 0.0007289023487828672
Epoch 1, Batch 477/1432, Loss: 0.0007235986995510757
Epoch 1, Batch 478/1432, Loss: 0.0007313599926419556
Epoch 1, Batch 479/1432, Loss: 0.0006485484773293138
Epoch 1, Batch 480/1432, Loss: 0.0007302596932277083
Epoch 1, Batch 481/1432, Loss: 0.0007332886452786624
Epoch 1, Batch 482/1432, Loss: 0.0006847494514659047
Epoch 1, Batch 483/1432, Loss: 0.0007328743813559413
Epoch 1, Batch 484/1432, Loss: 0.0006756462389603257
Epoch 1, Batch 485/1432, Loss: 0.0006918266881257296
Epoch 1, Batch 486/1432, Loss: 0.0007150456658564508
Epoch 1, Batch 487/1432, Loss: 0.0006761607364751399
Epoch 1, Batch 488/1432, Loss: 0.0006982507766224444
Epoch 1, Batch 489/1432, Loss: 0.0006982570048421621
Epoch 1, Batch 490/1432, Loss: 0.0006717231590300798
Epoch 1, Batch 491/1432, Loss: 0.0007018507458269596
Epoch 1, Batch 492/1432, Loss: 0.000661441357806325
Epoch 1, Batch 493/1432, Loss: 0.00067718478385

Epoch 1, Batch 631/1432, Loss: 0.0004555986379273236
Epoch 1, Batch 632/1432, Loss: 0.0004605460271704942
Epoch 1, Batch 633/1432, Loss: 0.0004619652754627168
Epoch 1, Batch 634/1432, Loss: 0.0004752526292577386
Epoch 1, Batch 635/1432, Loss: 0.0004675831878557801
Epoch 1, Batch 636/1432, Loss: 0.0004613881465047598
Epoch 1, Batch 637/1432, Loss: 0.00047710625221952796
Epoch 1, Batch 638/1432, Loss: 0.0004569443117361516
Epoch 1, Batch 639/1432, Loss: 0.00045268930261954665
Epoch 1, Batch 640/1432, Loss: 0.0004519833892118186
Epoch 1, Batch 641/1432, Loss: 0.00044500481453724205
Epoch 1, Batch 642/1432, Loss: 0.00045195993152447045
Epoch 1, Batch 643/1432, Loss: 0.00045162165770307183
Epoch 1, Batch 644/1432, Loss: 0.0004761026066262275
Epoch 1, Batch 645/1432, Loss: 0.0004511607112362981
Epoch 1, Batch 646/1432, Loss: 0.0004375193966552615
Epoch 1, Batch 647/1432, Loss: 0.0004574938502628356
Epoch 1, Batch 648/1432, Loss: 0.00042481670971028507
Epoch 1, Batch 649/1432, Loss: 0.0004439

Epoch 1, Batch 785/1432, Loss: 0.0003236084012314677
Epoch 1, Batch 786/1432, Loss: 0.00033854416687972844
Epoch 1, Batch 787/1432, Loss: 0.0003429929492995143
Epoch 1, Batch 788/1432, Loss: 0.00034854403929784894
Epoch 1, Batch 789/1432, Loss: 0.0003263613616582006
Epoch 1, Batch 790/1432, Loss: 0.00032458986970596015
Epoch 1, Batch 791/1432, Loss: 0.00033410798641853034
Epoch 1, Batch 792/1432, Loss: 0.0003324093122500926
Epoch 1, Batch 793/1432, Loss: 0.00034275854704901576
Epoch 1, Batch 794/1432, Loss: 0.00033868695027194917
Epoch 1, Batch 795/1432, Loss: 0.00033929647179320455
Epoch 1, Batch 796/1432, Loss: 0.0003270005399826914
Epoch 1, Batch 797/1432, Loss: 0.0003311522596050054
Epoch 1, Batch 798/1432, Loss: 0.00034080224577337503
Epoch 1, Batch 799/1432, Loss: 0.00034203732502646744
Epoch 1, Batch 800/1432, Loss: 0.0003091645485255867
Epoch 1, Batch 801/1432, Loss: 0.0003204428940080106
Epoch 1, Batch 802/1432, Loss: 0.00034530769335106015
Epoch 1, Batch 803/1432, Loss: 0.000

Epoch 1, Batch 939/1432, Loss: 0.0002509279001969844
Epoch 1, Batch 940/1432, Loss: 0.000247662392212078
Epoch 1, Batch 941/1432, Loss: 0.00024569578818045557
Epoch 1, Batch 942/1432, Loss: 0.0002548981283325702
Epoch 1, Batch 943/1432, Loss: 0.00023710043751634657
Epoch 1, Batch 944/1432, Loss: 0.0002454589703120291
Epoch 1, Batch 945/1432, Loss: 0.00025510889827273786
Epoch 1, Batch 946/1432, Loss: 0.00025028796517290175
Epoch 1, Batch 947/1432, Loss: 0.0002573309757281095
Epoch 1, Batch 948/1432, Loss: 0.00025969804846681654
Epoch 1, Batch 949/1432, Loss: 0.00024671858409419656
Epoch 1, Batch 950/1432, Loss: 0.00025731511414051056
Epoch 1, Batch 951/1432, Loss: 0.0002526042517274618
Epoch 1, Batch 952/1432, Loss: 0.0002436563663650304
Epoch 1, Batch 953/1432, Loss: 0.0002593838726170361
Epoch 1, Batch 954/1432, Loss: 0.00023401588259730488
Epoch 1, Batch 955/1432, Loss: 0.0002410377492196858
Epoch 1, Batch 956/1432, Loss: 0.00024926906917244196
Epoch 1, Batch 957/1432, Loss: 0.00025

Epoch 1, Batch 1091/1432, Loss: 0.00020402896916493773
Epoch 1, Batch 1092/1432, Loss: 0.00020285937353037298
Epoch 1, Batch 1093/1432, Loss: 0.00022717170941177756
Epoch 1, Batch 1094/1432, Loss: 0.00020921304530929774
Epoch 1, Batch 1095/1432, Loss: 0.00021047367772553116
Epoch 1, Batch 1096/1432, Loss: 0.00019177347712684423
Epoch 1, Batch 1097/1432, Loss: 0.0002067830937448889
Epoch 1, Batch 1098/1432, Loss: 0.00020010102889500558
Epoch 1, Batch 1099/1432, Loss: 0.000208024779567495
Epoch 1, Batch 1100/1432, Loss: 0.00020445183326955885
Epoch 1, Batch 1101/1432, Loss: 0.0001989447046071291
Epoch 1, Batch 1102/1432, Loss: 0.0001952088496182114
Epoch 1, Batch 1103/1432, Loss: 0.00020252710964996368
Epoch 1, Batch 1104/1432, Loss: 0.00020359842164907604
Epoch 1, Batch 1105/1432, Loss: 0.0001970769080799073
Epoch 1, Batch 1106/1432, Loss: 0.00019153089669998735
Epoch 1, Batch 1107/1432, Loss: 0.00018614767759572715
Epoch 1, Batch 1108/1432, Loss: 0.000206362281460315
Epoch 1, Batch 110

Epoch 1, Batch 1241/1432, Loss: 0.00017242970352526754
Epoch 1, Batch 1242/1432, Loss: 0.00017026007117237896
Epoch 1, Batch 1243/1432, Loss: 0.00016359155415557325
Epoch 1, Batch 1244/1432, Loss: 0.00017019321967381984
Epoch 1, Batch 1245/1432, Loss: 0.0001625911536393687
Epoch 1, Batch 1246/1432, Loss: 0.0001687526673777029
Epoch 1, Batch 1247/1432, Loss: 0.0001681967987678945
Epoch 1, Batch 1248/1432, Loss: 0.00015788132441230118
Epoch 1, Batch 1249/1432, Loss: 0.00017121934797614813
Epoch 1, Batch 1250/1432, Loss: 0.00017097836825996637
Epoch 1, Batch 1251/1432, Loss: 0.0001641422713873908
Epoch 1, Batch 1252/1432, Loss: 0.0001610528997844085
Epoch 1, Batch 1253/1432, Loss: 0.00015904338215477765
Epoch 1, Batch 1254/1432, Loss: 0.00015906200860626996
Epoch 1, Batch 1255/1432, Loss: 0.00017033024050761014
Epoch 1, Batch 1256/1432, Loss: 0.00016789273649919778
Epoch 1, Batch 1257/1432, Loss: 0.00015825754962861538
Epoch 1, Batch 1258/1432, Loss: 0.00015480269212275743
Epoch 1, Batch 

Epoch 1, Batch 1391/1432, Loss: 0.00014095252845436335
Epoch 1, Batch 1392/1432, Loss: 0.00013902685896027833
Epoch 1, Batch 1393/1432, Loss: 0.00013927824329584837
Epoch 1, Batch 1394/1432, Loss: 0.0001351921382592991
Epoch 1, Batch 1395/1432, Loss: 0.00013486067473422736
Epoch 1, Batch 1396/1432, Loss: 0.00015239852655213326
Epoch 1, Batch 1397/1432, Loss: 0.00013631337787956
Epoch 1, Batch 1398/1432, Loss: 0.00013563352695200592
Epoch 1, Batch 1399/1432, Loss: 0.00013760027650278062
Epoch 1, Batch 1400/1432, Loss: 0.0001382722402922809
Epoch 1, Batch 1401/1432, Loss: 0.00014054244093131274
Epoch 1, Batch 1402/1432, Loss: 0.00013389429659582675
Epoch 1, Batch 1403/1432, Loss: 0.00013915904855821282
Epoch 1, Batch 1404/1432, Loss: 0.00013895423035137355
Epoch 1, Batch 1405/1432, Loss: 0.00013482900976669043
Epoch 1, Batch 1406/1432, Loss: 0.00014068424934521317
Epoch 1, Batch 1407/1432, Loss: 0.0001290705695282668
Epoch 1, Batch 1408/1432, Loss: 0.00014064692368265241
Epoch 1, Batch 1

Epoch 2, Batch 113/1432, Loss: 0.0001203898573294282
Epoch 2, Batch 114/1432, Loss: 0.00011581020226003602
Epoch 2, Batch 115/1432, Loss: 0.00011257168807787821
Epoch 2, Batch 116/1432, Loss: 0.00011784393427660689
Epoch 2, Batch 117/1432, Loss: 0.00011591074144234881
Epoch 2, Batch 118/1432, Loss: 0.00011249688395764679
Epoch 2, Batch 119/1432, Loss: 0.00010988780559273437
Epoch 2, Batch 120/1432, Loss: 0.00012191865243948996
Epoch 2, Batch 121/1432, Loss: 0.00011554172670003027
Epoch 2, Batch 122/1432, Loss: 0.00011332766734994948
Epoch 2, Batch 123/1432, Loss: 0.00011187506606802344
Epoch 2, Batch 124/1432, Loss: 0.0001138526204158552
Epoch 2, Batch 125/1432, Loss: 0.00011153602099511772
Epoch 2, Batch 126/1432, Loss: 0.00011099025141447783
Epoch 2, Batch 127/1432, Loss: 0.0001156497819465585
Epoch 2, Batch 128/1432, Loss: 0.00010858221503440291
Epoch 2, Batch 129/1432, Loss: 0.00010946313705062494
Epoch 2, Batch 130/1432, Loss: 0.00010877787281060591
Epoch 2, Batch 131/1432, Loss: 

Epoch 2, Batch 267/1432, Loss: 9.664958633948117e-05
Epoch 2, Batch 268/1432, Loss: 0.00010136155469808728
Epoch 2, Batch 269/1432, Loss: 9.379447146784514e-05
Epoch 2, Batch 270/1432, Loss: 9.771849727258086e-05
Epoch 2, Batch 271/1432, Loss: 0.00010129630391020328
Epoch 2, Batch 272/1432, Loss: 9.786761802388355e-05
Epoch 2, Batch 273/1432, Loss: 0.00010691488569136709
Epoch 2, Batch 274/1432, Loss: 9.6885982202366e-05
Epoch 2, Batch 275/1432, Loss: 9.55376963247545e-05
Epoch 2, Batch 276/1432, Loss: 9.501994645688683e-05
Epoch 2, Batch 277/1432, Loss: 0.00010288493649568409
Epoch 2, Batch 278/1432, Loss: 9.375160880153999e-05
Epoch 2, Batch 279/1432, Loss: 9.645202226238325e-05
Epoch 2, Batch 280/1432, Loss: 0.00010536570334807038
Epoch 2, Batch 281/1432, Loss: 9.817680256674066e-05
Epoch 2, Batch 282/1432, Loss: 9.545195644022897e-05
Epoch 2, Batch 283/1432, Loss: 0.00010014912550104782
Epoch 2, Batch 284/1432, Loss: 9.669246355770156e-05
Epoch 2, Batch 285/1432, Loss: 9.7740921773

Epoch 2, Batch 423/1432, Loss: 8.37296320241876e-05
Epoch 2, Batch 424/1432, Loss: 8.808595885057002e-05
Epoch 2, Batch 425/1432, Loss: 8.834660548018292e-05
Epoch 2, Batch 426/1432, Loss: 8.696289296494797e-05
Epoch 2, Batch 427/1432, Loss: 8.187277126125991e-05
Epoch 2, Batch 428/1432, Loss: 8.830198203213513e-05
Epoch 2, Batch 429/1432, Loss: 8.667611837154254e-05
Epoch 2, Batch 430/1432, Loss: 8.136992255458608e-05
Epoch 2, Batch 431/1432, Loss: 7.828001253074035e-05
Epoch 2, Batch 432/1432, Loss: 8.323597285198048e-05
Epoch 2, Batch 433/1432, Loss: 8.214096305891871e-05
Epoch 2, Batch 434/1432, Loss: 8.146114851115271e-05
Epoch 2, Batch 435/1432, Loss: 8.179265569197014e-05
Epoch 2, Batch 436/1432, Loss: 8.181122393580154e-05
Epoch 2, Batch 437/1432, Loss: 8.181856537703425e-05
Epoch 2, Batch 438/1432, Loss: 8.04069495643489e-05
Epoch 2, Batch 439/1432, Loss: 8.475399226881564e-05
Epoch 2, Batch 440/1432, Loss: 8.49737916723825e-05
Epoch 2, Batch 441/1432, Loss: 8.640793384984136e

Epoch 2, Batch 579/1432, Loss: 7.318412099266425e-05
Epoch 2, Batch 580/1432, Loss: 7.303708116523921e-05
Epoch 2, Batch 581/1432, Loss: 7.180224201874807e-05
Epoch 2, Batch 582/1432, Loss: 6.732666224706918e-05
Epoch 2, Batch 583/1432, Loss: 7.184322021203116e-05
Epoch 2, Batch 584/1432, Loss: 7.612697663716972e-05
Epoch 2, Batch 585/1432, Loss: 7.45252036722377e-05
Epoch 2, Batch 586/1432, Loss: 7.180411193985492e-05
Epoch 2, Batch 587/1432, Loss: 7.416947482852265e-05
Epoch 2, Batch 588/1432, Loss: 7.48883539927192e-05
Epoch 2, Batch 589/1432, Loss: 7.4402239988558e-05
Epoch 2, Batch 590/1432, Loss: 6.99602605891414e-05
Epoch 2, Batch 591/1432, Loss: 7.084303797455505e-05
Epoch 2, Batch 592/1432, Loss: 7.222128624562174e-05
Epoch 2, Batch 593/1432, Loss: 7.334068504860625e-05
Epoch 2, Batch 594/1432, Loss: 7.165136048570275e-05
Epoch 2, Batch 595/1432, Loss: 7.314509275602177e-05
Epoch 2, Batch 596/1432, Loss: 7.076854672050104e-05
Epoch 2, Batch 597/1432, Loss: 7.695947715546936e-0

Epoch 2, Batch 735/1432, Loss: 6.453470996348187e-05
Epoch 2, Batch 736/1432, Loss: 6.429072527680546e-05
Epoch 2, Batch 737/1432, Loss: 6.378405669238418e-05
Epoch 2, Batch 738/1432, Loss: 6.519781891256571e-05
Epoch 2, Batch 739/1432, Loss: 6.17167170275934e-05
Epoch 2, Batch 740/1432, Loss: 6.143177597550675e-05
Epoch 2, Batch 741/1432, Loss: 6.393127114279196e-05
Epoch 2, Batch 742/1432, Loss: 6.327567825792357e-05
Epoch 2, Batch 743/1432, Loss: 6.188623228808865e-05
Epoch 2, Batch 744/1432, Loss: 6.383621075656265e-05
Epoch 2, Batch 745/1432, Loss: 6.44360261503607e-05
Epoch 2, Batch 746/1432, Loss: 6.708629371132702e-05
Epoch 2, Batch 747/1432, Loss: 6.179123738547787e-05
Epoch 2, Batch 748/1432, Loss: 6.085442510084249e-05
Epoch 2, Batch 749/1432, Loss: 6.393122748704627e-05
Epoch 2, Batch 750/1432, Loss: 6.0539608966792e-05
Epoch 2, Batch 751/1432, Loss: 6.307449075393379e-05
Epoch 2, Batch 752/1432, Loss: 6.27485933364369e-05
Epoch 2, Batch 753/1432, Loss: 6.592968566110358e-0

Epoch 2, Batch 891/1432, Loss: 5.713865175493993e-05
Epoch 2, Batch 892/1432, Loss: 5.464841524371877e-05
Epoch 2, Batch 893/1432, Loss: 5.849263470736332e-05
Epoch 2, Batch 894/1432, Loss: 5.425354538601823e-05
Epoch 2, Batch 895/1432, Loss: 5.396491542342119e-05
Epoch 2, Batch 896/1432, Loss: 5.7172153901774436e-05
Epoch 2, Batch 897/1432, Loss: 5.491658885148354e-05
Epoch 2, Batch 898/1432, Loss: 5.684994175680913e-05
Epoch 2, Batch 899/1432, Loss: 5.540835263673216e-05
Epoch 2, Batch 900/1432, Loss: 5.726148810936138e-05
Epoch 2, Batch 901/1432, Loss: 5.249719833955169e-05
Epoch 2, Batch 902/1432, Loss: 5.509726543095894e-05
Epoch 2, Batch 903/1432, Loss: 5.3836389270145446e-05
Epoch 2, Batch 904/1432, Loss: 5.649418380926363e-05
Epoch 2, Batch 905/1432, Loss: 5.521649291040376e-05
Epoch 2, Batch 906/1432, Loss: 5.7192657550331205e-05
Epoch 2, Batch 907/1432, Loss: 5.397422501118854e-05
Epoch 2, Batch 908/1432, Loss: 5.451805191114545e-05
Epoch 2, Batch 909/1432, Loss: 5.7183322496

Epoch 2, Batch 1045/1432, Loss: 4.8803765821503475e-05
Epoch 2, Batch 1046/1432, Loss: 4.917255137115717e-05
Epoch 2, Batch 1047/1432, Loss: 4.854301369050518e-05
Epoch 2, Batch 1048/1432, Loss: 5.004046397516504e-05
Epoch 2, Batch 1049/1432, Loss: 4.9932492402149364e-05
Epoch 2, Batch 1050/1432, Loss: 4.8706908273743466e-05
Epoch 2, Batch 1051/1432, Loss: 4.839399844058789e-05
Epoch 2, Batch 1052/1432, Loss: 4.7235513193299994e-05
Epoch 2, Batch 1053/1432, Loss: 5.000883174943738e-05
Epoch 2, Batch 1054/1432, Loss: 5.0899136113002896e-05
Epoch 2, Batch 1055/1432, Loss: 4.898817132925615e-05
Epoch 2, Batch 1056/1432, Loss: 4.622601773007773e-05
Epoch 2, Batch 1057/1432, Loss: 4.86286953673698e-05
Epoch 2, Batch 1058/1432, Loss: 4.968657231074758e-05
Epoch 2, Batch 1059/1432, Loss: 4.919489583699033e-05
Epoch 2, Batch 1060/1432, Loss: 5.020813114242628e-05
Epoch 2, Batch 1061/1432, Loss: 4.7950688895070925e-05
Epoch 2, Batch 1062/1432, Loss: 4.860074841417372e-05
Epoch 2, Batch 1063/143

Epoch 2, Batch 1197/1432, Loss: 4.412690759636462e-05
Epoch 2, Batch 1198/1432, Loss: 4.1994295315817e-05
Epoch 2, Batch 1199/1432, Loss: 4.530588194029406e-05
Epoch 2, Batch 1200/1432, Loss: 4.4564600102603436e-05
Epoch 2, Batch 1201/1432, Loss: 4.128836008021608e-05
Epoch 2, Batch 1202/1432, Loss: 4.553500184556469e-05
Epoch 2, Batch 1203/1432, Loss: 4.4672629883280024e-05
Epoch 2, Batch 1204/1432, Loss: 4.287712727091275e-05
Epoch 2, Batch 1205/1432, Loss: 4.5507047616411e-05
Epoch 2, Batch 1206/1432, Loss: 4.4471460569184273e-05
Epoch 2, Batch 1207/1432, Loss: 4.140198143431917e-05
Epoch 2, Batch 1208/1432, Loss: 4.46558442490641e-05
Epoch 2, Batch 1209/1432, Loss: 4.229039768688381e-05
Epoch 2, Batch 1210/1432, Loss: 4.3992800783598796e-05
Epoch 2, Batch 1211/1432, Loss: 4.4775053538614884e-05
Epoch 2, Batch 1212/1432, Loss: 4.2752319131977856e-05
Epoch 2, Batch 1213/1432, Loss: 4.103133323951624e-05
Epoch 2, Batch 1214/1432, Loss: 4.318628634791821e-05
Epoch 2, Batch 1215/1432, L

Epoch 2, Batch 1349/1432, Loss: 4.015591548522934e-05
Epoch 2, Batch 1350/1432, Loss: 3.9451868360629305e-05
Epoch 2, Batch 1351/1432, Loss: 3.798788020503707e-05
Epoch 2, Batch 1352/1432, Loss: 3.996035593445413e-05
Epoch 2, Batch 1353/1432, Loss: 3.6514604289550334e-05
Epoch 2, Batch 1354/1432, Loss: 3.80847486667335e-05
Epoch 2, Batch 1355/1432, Loss: 3.815737727563828e-05
Epoch 2, Batch 1356/1432, Loss: 3.7296853406587616e-05
Epoch 2, Batch 1357/1432, Loss: 3.7091995181981474e-05
Epoch 2, Batch 1358/1432, Loss: 3.9878388633951545e-05
Epoch 2, Batch 1359/1432, Loss: 3.905139965354465e-05
Epoch 2, Batch 1360/1432, Loss: 4.135720882914029e-05
Epoch 2, Batch 1361/1432, Loss: 4.012235876871273e-05
Epoch 2, Batch 1362/1432, Loss: 3.956734508392401e-05
Epoch 2, Batch 1363/1432, Loss: 3.812571958405897e-05
Epoch 2, Batch 1364/1432, Loss: 3.819090488832444e-05
Epoch 2, Batch 1365/1432, Loss: 3.8419988413807005e-05
Epoch 2, Batch 1366/1432, Loss: 3.862674566335045e-05
Epoch 2, Batch 1367/143

Epoch 3, Batch 71/1432, Loss: 3.461475353105925e-05
Epoch 3, Batch 72/1432, Loss: 3.355495573487133e-05
Epoch 3, Batch 73/1432, Loss: 3.4169603168265894e-05
Epoch 3, Batch 74/1432, Loss: 3.48345456586685e-05
Epoch 3, Batch 75/1432, Loss: 3.5119523090543225e-05
Epoch 3, Batch 76/1432, Loss: 3.5605648008640856e-05
Epoch 3, Batch 77/1432, Loss: 3.390883648535237e-05
Epoch 3, Batch 78/1432, Loss: 3.554232171154581e-05
Epoch 3, Batch 79/1432, Loss: 3.4674365451792255e-05
Epoch 3, Batch 80/1432, Loss: 3.4367043554084376e-05
Epoch 3, Batch 81/1432, Loss: 3.3778469514800236e-05
Epoch 3, Batch 82/1432, Loss: 3.51809649146162e-05
Epoch 3, Batch 83/1432, Loss: 3.584965452319011e-05
Epoch 3, Batch 84/1432, Loss: 3.4538406907813624e-05
Epoch 3, Batch 85/1432, Loss: 3.464457768131979e-05
Epoch 3, Batch 86/1432, Loss: 3.638606722233817e-05
Epoch 3, Batch 87/1432, Loss: 3.3094904210884124e-05
Epoch 3, Batch 88/1432, Loss: 3.149680560454726e-05
Epoch 3, Batch 89/1432, Loss: 3.465574627625756e-05
Epoch 

Epoch 3, Batch 227/1432, Loss: 3.07871559925843e-05
Epoch 3, Batch 228/1432, Loss: 3.097341686952859e-05
Epoch 3, Batch 229/1432, Loss: 2.9712442483287305e-05
Epoch 3, Batch 230/1432, Loss: 3.096037107752636e-05
Epoch 3, Batch 231/1432, Loss: 3.030287189176306e-05
Epoch 3, Batch 232/1432, Loss: 3.1053506972966716e-05
Epoch 3, Batch 233/1432, Loss: 3.0463055736618116e-05
Epoch 3, Batch 234/1432, Loss: 3.321966505609453e-05
Epoch 3, Batch 235/1432, Loss: 2.9950855605420657e-05
Epoch 3, Batch 236/1432, Loss: 2.977576878038235e-05
Epoch 3, Batch 237/1432, Loss: 3.2495147024746984e-05
Epoch 3, Batch 238/1432, Loss: 3.0176228392519988e-05
Epoch 3, Batch 239/1432, Loss: 3.3193588024005294e-05
Epoch 3, Batch 240/1432, Loss: 3.237216151319444e-05
Epoch 3, Batch 241/1432, Loss: 3.1636489438824356e-05
Epoch 3, Batch 242/1432, Loss: 2.9297085347934626e-05
Epoch 3, Batch 243/1432, Loss: 3.0148295991239138e-05
Epoch 3, Batch 244/1432, Loss: 2.9807448299834505e-05
Epoch 3, Batch 245/1432, Loss: 2.897

Epoch 3, Batch 381/1432, Loss: 2.7851725462824106e-05
Epoch 3, Batch 382/1432, Loss: 2.746243990259245e-05
Epoch 3, Batch 383/1432, Loss: 2.7332062018103898e-05
Epoch 3, Batch 384/1432, Loss: 2.8749482225975953e-05
Epoch 3, Batch 385/1432, Loss: 2.742331889749039e-05
Epoch 3, Batch 386/1432, Loss: 2.723147554206662e-05
Epoch 3, Batch 387/1432, Loss: 2.7669171686284244e-05
Epoch 3, Batch 388/1432, Loss: 2.8051017579855397e-05
Epoch 3, Batch 389/1432, Loss: 2.9518732844735496e-05
Epoch 3, Batch 390/1432, Loss: 2.6894342227024026e-05
Epoch 3, Batch 391/1432, Loss: 2.606549605843611e-05
Epoch 3, Batch 392/1432, Loss: 2.8019359888276085e-05
Epoch 3, Batch 393/1432, Loss: 2.8116202884120867e-05
Epoch 3, Batch 394/1432, Loss: 2.696885530895088e-05
Epoch 3, Batch 395/1432, Loss: 2.7620764740277082e-05
Epoch 3, Batch 396/1432, Loss: 2.7631938792183064e-05
Epoch 3, Batch 397/1432, Loss: 2.6108338715857826e-05
Epoch 3, Batch 398/1432, Loss: 2.633184521982912e-05
Epoch 3, Batch 399/1432, Loss: 2.8

Epoch 3, Batch 535/1432, Loss: 2.5817775167524815e-05
Epoch 3, Batch 536/1432, Loss: 2.6428691853652708e-05
Epoch 3, Batch 537/1432, Loss: 2.5435942006879486e-05
Epoch 3, Batch 538/1432, Loss: 2.5648278096923605e-05
Epoch 3, Batch 539/1432, Loss: 2.4458075131406076e-05
Epoch 3, Batch 540/1432, Loss: 2.4888340703910217e-05
Epoch 3, Batch 541/1432, Loss: 2.483059506630525e-05
Epoch 3, Batch 542/1432, Loss: 2.63486090261722e-05
Epoch 3, Batch 543/1432, Loss: 2.64454629359534e-05
Epoch 3, Batch 544/1432, Loss: 2.576375482021831e-05
Epoch 3, Batch 545/1432, Loss: 2.5672485207905993e-05
Epoch 3, Batch 546/1432, Loss: 2.372235212533269e-05
Epoch 3, Batch 547/1432, Loss: 2.530928213673178e-05
Epoch 3, Batch 548/1432, Loss: 2.4558657969464548e-05
Epoch 3, Batch 549/1432, Loss: 2.3977525415830314e-05
Epoch 3, Batch 550/1432, Loss: 2.5355839170515537e-05
Epoch 3, Batch 551/1432, Loss: 2.402967948000878e-05
Epoch 3, Batch 552/1432, Loss: 2.7175598006579094e-05
Epoch 3, Batch 553/1432, Loss: 2.4089

Epoch 3, Batch 689/1432, Loss: 2.2429705495596863e-05
Epoch 3, Batch 690/1432, Loss: 2.2690463083563372e-05
Epoch 3, Batch 691/1432, Loss: 2.2437150619225577e-05
Epoch 3, Batch 692/1432, Loss: 2.3144939405028708e-05
Epoch 3, Batch 693/1432, Loss: 2.3046222850098275e-05
Epoch 3, Batch 694/1432, Loss: 2.2152175006340258e-05
Epoch 3, Batch 695/1432, Loss: 2.266252886329312e-05
Epoch 3, Batch 696/1432, Loss: 2.3977525415830314e-05
Epoch 3, Batch 697/1432, Loss: 2.2960539354244247e-05
Epoch 3, Batch 698/1432, Loss: 2.1688390916096978e-05
Epoch 3, Batch 699/1432, Loss: 2.269605829496868e-05
Epoch 3, Batch 700/1432, Loss: 2.245764153485652e-05
Epoch 3, Batch 701/1432, Loss: 2.1673487935913727e-05
Epoch 3, Batch 702/1432, Loss: 2.3372176656266674e-05
Epoch 3, Batch 703/1432, Loss: 2.4418963221251033e-05
Epoch 3, Batch 704/1432, Loss: 2.3098378733266145e-05
Epoch 3, Batch 705/1432, Loss: 2.270723598485347e-05
Epoch 3, Batch 706/1432, Loss: 2.224158015451394e-05
Epoch 3, Batch 707/1432, Loss: 2.

Epoch 3, Batch 843/1432, Loss: 2.0693762053269893e-05
Epoch 3, Batch 844/1432, Loss: 2.0019493604195304e-05
Epoch 3, Batch 845/1432, Loss: 2.0185267203487456e-05
Epoch 3, Batch 846/1432, Loss: 2.0924720956827514e-05
Epoch 3, Batch 847/1432, Loss: 2.013125595112797e-05
Epoch 3, Batch 848/1432, Loss: 2.0639745343942195e-05
Epoch 3, Batch 849/1432, Loss: 2.088374094455503e-05
Epoch 3, Batch 850/1432, Loss: 2.0682589820353314e-05
Epoch 3, Batch 851/1432, Loss: 2.014615347434301e-05
Epoch 3, Batch 852/1432, Loss: 2.0045570636284538e-05
Epoch 3, Batch 853/1432, Loss: 2.0816689357161522e-05
Epoch 3, Batch 854/1432, Loss: 2.0146155293332413e-05
Epoch 3, Batch 855/1432, Loss: 2.1284204194671474e-05
Epoch 3, Batch 856/1432, Loss: 2.116312680300325e-05
Epoch 3, Batch 857/1432, Loss: 2.0149876945652068e-05
Epoch 3, Batch 858/1432, Loss: 2.1513300453079864e-05
Epoch 3, Batch 859/1432, Loss: 1.9572467863326892e-05
Epoch 3, Batch 860/1432, Loss: 2.134753049176652e-05
Epoch 3, Batch 861/1432, Loss: 2.

Epoch 3, Batch 997/1432, Loss: 1.844745384005364e-05
Epoch 3, Batch 998/1432, Loss: 1.934710053319577e-05
Epoch 3, Batch 999/1432, Loss: 2.0025074263685383e-05
Epoch 3, Batch 1000/1432, Loss: 1.9261417037341744e-05
Epoch 3, Batch 1001/1432, Loss: 1.9224164134357125e-05
Epoch 3, Batch 1002/1432, Loss: 1.8197864847024903e-05
Epoch 3, Batch 1003/1432, Loss: 1.9600407540565357e-05
Epoch 3, Batch 1004/1432, Loss: 1.8337563233217224e-05
Epoch 3, Batch 1005/1432, Loss: 1.8903792806668207e-05
Epoch 3, Batch 1006/1432, Loss: 1.8125227143173106e-05
Epoch 3, Batch 1007/1432, Loss: 1.9730778149096295e-05
Epoch 3, Batch 1008/1432, Loss: 1.9669325411086902e-05
Epoch 3, Batch 1009/1432, Loss: 1.8240711142425425e-05
Epoch 3, Batch 1010/1432, Loss: 1.907515434140805e-05
Epoch 3, Batch 1011/1432, Loss: 1.900437564472668e-05
Epoch 3, Batch 1012/1432, Loss: 1.8430697309668176e-05
Epoch 3, Batch 1013/1432, Loss: 1.882742981251795e-05
Epoch 3, Batch 1014/1432, Loss: 1.8060036381939426e-05
Epoch 3, Batch 101

Epoch 3, Batch 1149/1432, Loss: 1.7149217455880716e-05
Epoch 3, Batch 1150/1432, Loss: 1.6461912309750915e-05
Epoch 3, Batch 1151/1432, Loss: 1.7199508874909952e-05
Epoch 3, Batch 1152/1432, Loss: 1.671150494075846e-05
Epoch 3, Batch 1153/1432, Loss: 1.7946415027836338e-05
Epoch 3, Batch 1154/1432, Loss: 1.646563941903878e-05
Epoch 3, Batch 1155/1432, Loss: 1.6979718566290103e-05
Epoch 3, Batch 1156/1432, Loss: 1.831893496273551e-05
Epoch 3, Batch 1157/1432, Loss: 1.6681702618370764e-05
Epoch 3, Batch 1158/1432, Loss: 1.7486343494965695e-05
Epoch 3, Batch 1159/1432, Loss: 1.725352012726944e-05
Epoch 3, Batch 1160/1432, Loss: 1.73652806552127e-05
Epoch 3, Batch 1161/1432, Loss: 1.696854451438412e-05
Epoch 3, Batch 1162/1432, Loss: 1.6808358850539662e-05
Epoch 3, Batch 1163/1432, Loss: 1.7115688024205156e-05
Epoch 3, Batch 1164/1432, Loss: 1.7130589185399003e-05
Epoch 3, Batch 1165/1432, Loss: 1.6996484191622585e-05
Epoch 3, Batch 1166/1432, Loss: 1.5815587175893597e-05
Epoch 3, Batch 11

Epoch 3, Batch 1299/1432, Loss: 1.599440111021977e-05
Epoch 3, Batch 1300/1432, Loss: 1.6191832401091233e-05
Epoch 3, Batch 1301/1432, Loss: 1.469988092139829e-05
Epoch 3, Batch 1302/1432, Loss: 1.5325716958614066e-05
Epoch 3, Batch 1303/1432, Loss: 1.606704063306097e-05
Epoch 3, Batch 1304/1432, Loss: 1.5832349163247272e-05
Epoch 3, Batch 1305/1432, Loss: 1.5919891666271724e-05
Epoch 3, Batch 1306/1432, Loss: 1.57373578986153e-05
Epoch 3, Batch 1307/1432, Loss: 1.557903306093067e-05
Epoch 3, Batch 1308/1432, Loss: 1.5869600247242488e-05
Epoch 3, Batch 1309/1432, Loss: 1.566098944749683e-05
Epoch 3, Batch 1310/1432, Loss: 1.593479282746557e-05
Epoch 3, Batch 1311/1432, Loss: 1.538532342237886e-05
Epoch 3, Batch 1312/1432, Loss: 1.551384229969699e-05
Epoch 3, Batch 1313/1432, Loss: 1.6502890503033996e-05
Epoch 3, Batch 1314/1432, Loss: 1.6081941794254817e-05
Epoch 3, Batch 1315/1432, Loss: 1.6555039110244252e-05
Epoch 3, Batch 1316/1432, Loss: 1.5219548004097305e-05
Epoch 3, Batch 1317/

NameError: name 'train_features' is not defined