In [None]:
import zipfile

# Location of the uploaded dataset archive and the folder it is unpacked into.
zip_path = "/content/FER.zip"   # name of the uploaded ZIP file
extract_to = "/content/FER"     # folder the dataset is extracted into

# Unpack every archive member under extract_to.
with zipfile.ZipFile(zip_path) as archive:
    archive.extractall(path=extract_to)

print("압축 해제 완료!")

압축 해제 완료!


In [None]:
# Import basic packages first
import os
import sys
import numpy as np
import pandas as pd
from PIL import Image
import pywt

# Then import torch and its modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn

# Import other utilities
import gc
import traceback
from typing import Tuple, List

# Fix the RNGs up front: torch drives weight init and DataLoader shuffling,
# numpy drives the mixup coefficient.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Report the installed PyTorch build and whether a GPU is visible.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN kernels, no autotuning: favour reproducibility over speed.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def centralize_gradient(x, gc_axis=None):
    """Gradient Centralization: subtract a mean from a multi-dimensional gradient.

    Args:
        x: gradient tensor. Tensors with at most one dimension (biases, norm
            scales) are returned unchanged.
        gc_axis: dimension(s) to average over. ``None`` (default) averages over
            every dimension except dim 0, so each output unit's slice ``x[i]``
            is shifted to zero mean. This is the Gradient Centralization
            formulation for PyTorch's ``(out, in, ...)`` weight layout. An int
            or an iterable of ints selects the dimensions explicitly; for
            example, ``gc_axis=0`` averages across output units instead.

    Returns:
        A new, centred tensor. ``x`` itself is not modified.
    """
    if x.dim() <= 1:
        return x
    if gc_axis is None:
        axis = tuple(range(1, x.dim()))
    elif isinstance(gc_axis, int):
        axis = (gc_axis,)
    else:
        axis = tuple(gc_axis)
    return x - x.mean(dim=axis, keepdim=True)

class GC_AdamW(optim.AdamW):
    """AdamW with Gradient Centralization (GC) applied before every update.

    For each parameter whose gradient has more than one dimension (conv and
    linear weights), the gradient slice of every output unit (``grad[i]``) is
    shifted to zero mean over its remaining dimensions. One-dimensional
    gradients (biases, BatchNorm affine terms) are left untouched. The AdamW
    update itself is delegated unchanged to ``torch.optim.AdamW``.
    """

    @torch.no_grad()
    def _centralize_gradients(self):
        """Centre every multi-dimensional gradient in place (per output unit)."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or p.grad.dim() <= 1:
                    continue
                reduce_dims = tuple(range(1, p.grad.dim()))
                p.grad.sub_(p.grad.mean(dim=reduce_dims, keepdim=True))

    def step(self, closure=None):
        """Centralize gradients, then perform one AdamW step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is evaluated exactly once, before centralization.
                It is deliberately not forwarded to ``AdamW.step``, which would
                run it a second time and overwrite the centralized gradients.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # Grad mode is required for the closure's backward pass.
            with torch.enable_grad():
                loss = closure()

        self._centralize_gradients()
        super().step()
        return loss

def mixup_data(x, y, alpha=0.2):
    """Blend every sample with a randomly chosen partner from the same batch (mixup).

    Args:
        x: batch of inputs; dim 0 is the batch dimension.
        y: labels aligned with ``x`` along dim 0.
        alpha: Beta(alpha, alpha) concentration for the mixing weight. When
            ``alpha <= 0`` no mixing happens and the weight is exactly 1.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y`` and ``y_b = y[perm]`` for a random permutation ``perm``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    partner = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[partner]
    return blended, y, y[partner], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

class WaveletTransform:
    """Turn a 2-D image into stacked, standardized wavelet coefficient maps.

    ``pywt.wavedec2`` splits the image into one approximation band plus
    per-level detail bands. Each detail band is nearest-neighbour resized to
    the approximation band's shape, every band is z-scored independently, and
    the bands are stacked along a new leading axis as float32.
    """

    def __init__(self, wavelet='db4', level=2):
        self.wavelet = wavelet
        self.level = level

    @staticmethod
    def _standardize(band):
        """Zero-mean / unit-std scaling; the 1e-8 guards against flat bands."""
        return (band - np.mean(band)) / (np.std(band) + 1e-8)

    def apply_transform(self, img):
        """Return a ``(n_bands, h, w)`` float32 array of wavelet feature maps.

        The approximation band comes first, followed by the detail bands in
        the order ``pywt.wavedec2`` returns them.
        """
        try:
            pixels = np.array(img, dtype=np.float32)
            approx, *detail_levels = pywt.wavedec2(pixels, self.wavelet, level=self.level)

            approx = self._standardize(approx)
            channels = [approx]

            for level_bands in detail_levels:
                if not isinstance(level_bands, tuple):
                    continue
                # Detail bands of finer levels are larger; bring them to the
                # approximation band's grid before standardizing.
                channels.extend(
                    self._standardize(resize_array(band, approx.shape))
                    for band in level_bands
                )

            return np.stack(channels).astype(np.float32)

        except Exception as e:
            print(f"Error in wavelet transform: {str(e)}")
            raise

def resize_array(arr, target_shape):
    """Nearest-neighbour resize of the first two axes of ``arr`` to ``target_shape``.

    Output row ``i`` takes source row ``floor(i / (out_h / src_h))``, and the
    same rule applies to columns. Every output element is therefore an
    existing source element, for both down- and up-sampling. Trailing axes
    (if any) are carried through.
    """
    out_h, out_w = target_shape[0], target_shape[1]
    src_h, src_w = arr.shape[0], arr.shape[1]

    rows = np.floor(np.arange(out_h) / (out_h / src_h)).astype(int)
    cols = np.floor(np.arange(out_w) / (out_w / src_w)).astype(int)

    # Broadcast (out_h, 1) against (out_w,) to gather the full output grid at once.
    return arr[rows[:, None], cols]

class CustomDataset(Dataset):
    """FER image dataset yielding ``(wavelet_tensor, label)`` pairs.

    The label CSV is read positionally:
    - column 0 holds the image filename, relative to ``img_dir``;
    - columns 2-11 hold ten per-emotion scores.

    The label is the position of the highest score (the first one on ties).
    Each image is loaded as grayscale, passed through ``transform`` and then
    through ``WaveletTransform.apply_transform``.

    Args:
        csv_file: path to the label CSV.
        img_dir: directory containing the images listed in the CSV.
        transform: optional callable applied to the grayscale PIL image. It may
            return a tensor (e.g. a pipeline ending in ``ToTensor``/``Normalize``).
            Any other result is converted with ``numpy.asarray``.
        train: when True, the first rows of the CSV are printed for debugging.
        header: forwarded to ``pandas.read_csv``. Defaults to ``None`` because
            the label CSVs have no header row. Pandas' default
            (``header='infer'``) would silently turn the first sample into
            column names and drop it. Pass ``0`` for a CSV that has a header.
    """

    def __init__(self, csv_file, img_dir, transform=None, train=True, header=None):
        try:
            self.data = pd.read_csv(csv_file, header=header)
            self.img_dir = img_dir
            self.transform = transform
            self.train = train
            self.wavelet_transform = WaveletTransform()

            print(f"\nDataset size: {len(self.data)}")
            print(f"CSV file: {csv_file}")
            print(f"Image directory: {img_dir}")

            # Print column names for debugging
            print("\nCSV columns:")
            print(self.data.columns.tolist())

            if train:
                print("\nFirst few rows of data:")
                print(self.data.head())

        except Exception as e:
            print(f"Error initializing dataset: {str(e)}")
            print("CSV columns:", self.data.columns if hasattr(self, 'data') else "No data loaded")
            raise

    def __len__(self):
        """Return the size of the dataset"""
        return len(self.data) if hasattr(self, 'data') else 0

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(float32 wavelet tensor, int label)``."""
        # Bound before the try so the error handler can always report it.
        img_path = None
        try:
            # Image filename lives in the first column
            img_filename = str(self.data.iloc[idx, 0])
            img_path = os.path.join(self.img_dir, img_filename)

            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image file not found: {img_path}")

            # Decode inside ``with`` so the file handle is closed right away
            # (matters with many DataLoader worker reads).
            with Image.open(img_path) as img:
                image = img.convert('L')

            # Only a warning: resizing is left to ``transform``.
            if image.size != (48, 48):
                print(f"Warning: Image {img_path} size is {image.size}, resizing to (48, 48)")

            if self.transform:
                image = self.transform(image)

            # Label = index of the highest emotion score (argmax, first on ties)
            emotion_scores = self.data.iloc[idx, 2:12].astype(float)
            label = int(emotion_scores.argmax())

            if torch.is_tensor(image):
                img_np = image.squeeze().numpy()
            else:
                # No tensor-producing transform: use the raw pixel values.
                img_np = np.asarray(image, dtype=np.float32).squeeze()

            # Debug information for first image
            if idx == 0:
                print(f"\nDebug - First image:")
                print(f"Image shape: {img_np.shape}")
                print(f"Image min/max: {img_np.min():.2f}/{img_np.max():.2f}")
                print(f"Label: {label}")
                print(f"Emotion scores: {emotion_scores.tolist()}")

            wavelet_features = self.wavelet_transform.apply_transform(img_np)

            if idx == 0:
                print(f"Wavelet features shape: {wavelet_features.shape}")
                print(f"Wavelet features min/max: {wavelet_features.min():.2f}/{wavelet_features.max():.2f}")

            wavelet_tensor = torch.from_numpy(wavelet_features).float()

            return wavelet_tensor, label

        except Exception as e:
            print(f"Error loading image {idx}: {str(e)}")
            print(f"Image path: {img_path}")
            print(f"Data row {idx}:", self.data.iloc[idx].tolist())
            raise

class SelfAttention(nn.Module):
    """Spatial self-attention over all H*W positions with a learnable residual gate.

    Queries and keys are 1x1-conv projections to ``in_channels // 8`` channels;
    values keep ``in_channels``. The attended map is scaled by ``gamma``, which
    starts at 0 so the layer is initially the identity, and added to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        # Attribute names and creation order are kept as-is: they determine the
        # state_dict keys and the order parameters consume the RNG at init.
        self.query = nn.Conv2d(in_channels, reduced, 1)
        self.key = nn.Conv2d(in_channels, reduced, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        tokens = h * w

        queries = self.query(x).view(n, -1, tokens).transpose(1, 2)   # (n, HW, C/8)
        keys = self.key(x).view(n, -1, tokens)                         # (n, C/8, HW)
        values = self.value(x).view(n, -1, tokens)                     # (n, C, HW)

        # Row i of ``weights`` is query position i's distribution over all keys.
        weights = torch.softmax(torch.bmm(queries, keys), dim=-1)     # (n, HW, HW)
        attended = torch.bmm(values, weights.transpose(1, 2)).view(n, c, h, w)

        return self.gamma * attended + x

class KTN(nn.Module):
    """Convolutional classifier over stacked wavelet feature channels.

    Structure:
    - ``features``: three stages of Conv3x3 + BatchNorm + ReLU + Dropout2d +
      2x2 max-pool, widening to 128, 256 and 512 channels. Each pool halves
      H and W (floor).
    - ``transfer``: two 768-channel Conv3x3 + BatchNorm + ReLU blocks, each
      followed by ``SelfAttention`` (whose output is gated and added to its
      input) and Dropout2d.
    - ``classifier``: global average pooling and a 768 -> 512 -> 256 ->
      ``num_classes`` MLP with dropout 0.5.

    NOTE: the layer order inside each ``nn.Sequential`` determines the
    state_dict keys (e.g. ``features.0.weight``) stored in the saved
    checkpoint. Reordering or inserting layers breaks loading existing
    checkpoints.

    Args:
        num_classes: number of output logits.
        input_channels: channels of the input tensor. ``main`` uses 7,
            presumably the number of stacked wavelet maps; confirm if the
            WaveletTransform level changes.
    """

    def __init__(self, num_classes=10, input_channels=7):
        super(KTN, self).__init__()

        print(f"\nInitializing KTN with:")
        print(f"Number of classes: {num_classes}")
        print(f"Input channels: {input_channels}")

        # Feature extractor: 3 conv stages with increasing dropout, each downsampling by 2
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.4),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Two 768-channel conv blocks, each followed by self-attention (the only
        # residual paths here are the gated skips inside SelfAttention); no pooling
        self.transfer = nn.Sequential(
            nn.Conv2d(512, 768, kernel_size=3, padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
            SelfAttention(768),
            nn.Dropout2d(0.4),

            nn.Conv2d(768, 768, kernel_size=3, padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
            SelfAttention(768),
            nn.Dropout2d(0.4)
        )

        # Head: global average pool to a 768-d vector, then a 3-layer MLP with dropout
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(768, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

        self._initialize_weights()

    def forward(self, x):
        """Map a ``(N, input_channels, H, W)`` batch to ``(N, num_classes)`` logits."""
        x = self.features(x)
        x = self.transfer(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Re-initialize every Conv2d, BatchNorm2d and Linear module, including those inside SelfAttention."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He/Kaiming normal init suited to the ReLU activations
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # Small-std normal init keeps the initial logits near zero
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

def setup_cuda(deterministic=True):
    """Select the compute device and configure cuDNN.

    Args:
        deterministic: when True (default), cuDNN is restricted to
            deterministic algorithms and autotuning (``cudnn.benchmark``) is
            disabled, in line with the reproducibility settings next to the
            seeding code. Pass False to enable autotuning for speed at the cost
            of run-to-run reproducibility.

    Returns:
        torch.device: ``cuda`` when available and initialization succeeds,
        otherwise ``cpu``. Initialization errors are printed, not raised.
    """
    try:
        if not torch.cuda.is_available():
            print("CUDA not available, using CPU")
            return torch.device("cpu")

        device = torch.device("cuda")

        # Start from a clean allocator cache and fresh peak-memory counters
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # benchmark=True times candidate kernels and may choose different
        # algorithms per run. It must stay off when determinism is requested,
        # otherwise it undoes deterministic=True's reproducibility.
        cudnn.deterministic = deterministic
        cudnn.benchmark = not deterministic

        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated(0)/(1024*1024):.2f} MB")
        return device

    except Exception as e:
        print(f"CUDA initialization failed: {str(e)}")
        print("Falling back to CPU")
        return torch.device("cpu")

def train_model(model, train_loader, criterion, optimizer, device, epoch):
    """Run one mixup training epoch and return the mixup-weighted accuracy in percent.

    Per batch:
    - mix the inputs with ``mixup_data`` (default alpha);
    - compute the mixup-weighted loss;
    - clip gradients to a global L2 norm of 1.0;
    - call ``optimizer.step()``.

    A batch that raises is reported with its traceback and skipped, so one
    failure does not abort the epoch.

    Args:
        model: network to train (switched to train mode).
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: per-target loss, combined via ``mixup_criterion``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        epoch: 0-based epoch index, used only for logging.

    Returns:
        float: accuracy in percent. Each prediction earns ``lam`` credit for
        matching ``labels_a`` and ``1 - lam`` for matching ``labels_b``.
        Returns 0.0 if no batch completed.
    """
    model.train()
    running_loss = 0.0
    correct = 0.0
    total = 0
    completed_batches = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        try:
            images = images.float().to(device)
            labels = labels.long().to(device)

            # Apply mixup
            images, labels_a, labels_b, lam = mixup_data(images, labels)

            optimizer.zero_grad()
            outputs = model(images)

            # Use mixup loss
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accuracy under mixup: credit is split between the two target sets
            running_loss += loss.item()
            completed_batches += 1
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (lam * predicted.eq(labels_a).float() +
                        (1 - lam) * predicted.eq(labels_b).float()).sum().item()

            if batch_idx % 100 == 0:
                # Average over batches that actually completed, not batch_idx + 1
                print(f'Epoch: {epoch+1}, Batch: {batch_idx}, '
                      f'Loss: {running_loss/completed_batches:.4f}, '
                      f'Acc: {100.*correct/total:.2f}%')

        except Exception as e:
            print(f"Error in batch {batch_idx}: {str(e)}")
            print("Traceback:")
            traceback.print_exc()
            continue

    if total == 0:
        print("Warning: no training batch completed this epoch")
        return 0.0
    return 100. * correct / total

def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    Batches that raise are reported and excluded from both metrics. A batch's
    numbers are only accumulated once it has fully succeeded.

    Args:
        model: network to evaluate.
        val_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function applied to ``(logits, labels)``.
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``:
        - ``avg_loss`` is the mean of the per-batch criterion values over the
          batches evaluated successfully;
        - ``accuracy`` is top-1 accuracy in percent over their samples.

        Returns ``(float('nan'), 0.0)`` when no batch could be evaluated.
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    evaluated_batches = 0

    with torch.no_grad():
        for images, labels in val_loader:
            try:
                images = images.float().to(device)
                labels = labels.long().to(device)

                outputs = model(images)
                batch_loss = criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                batch_correct = predicted.eq(labels).sum().item()
                batch_size = labels.size(0)
            except Exception as e:
                print(f"Error in validation: {str(e)}")
                continue

            val_loss += batch_loss
            correct += batch_correct
            total += batch_size
            evaluated_batches += 1

    if evaluated_batches == 0 or total == 0:
        print("Warning: no validation batch could be evaluated")
        return float('nan'), 0.0

    avg_loss = val_loss / evaluated_batches
    accuracy = 100. * correct / total
    return avg_loss, accuracy

def verify_submission(submission_path, num_classes=10):
    """
    Verify the submission file format.

    Hard checks (any failure returns False):
    - ``ID`` and ``Prediction`` columns are present;
    - the file has at least one row;
    - no prediction is missing;
    - IDs are unique;
    - every prediction lies in ``[0, num_classes)``.

    An ID range other than ``0 .. len-1`` only prints a warning.

    Args:
        submission_path: path to the submission CSV.
        num_classes: number of valid classes; predictions must be below it.

    Returns:
        bool: True if all hard checks pass. False otherwise, including when
        the file cannot be read.
    """
    try:
        df = pd.read_csv(submission_path)

        # Check columns
        expected_columns = ['ID', 'Prediction']
        if not all(col in df.columns for col in expected_columns):
            print("Error: Missing required columns")
            return False

        # An empty frame makes min()/max() NaN, and NaN comparisons are always
        # False, so the range checks below would silently pass.
        if df.empty:
            print("Error: Submission contains no predictions")
            return False

        if df['Prediction'].isna().any():
            print("Error: Missing predictions")
            return False

        if df['ID'].duplicated().any():
            print("Error: Duplicate IDs found")
            return False

        # Check ID range (now starting from 0)
        if df['ID'].min() != 0 or df['ID'].max() != len(df) - 1:
            print("Warning: ID range might be incorrect")
            print(f"Expected range: 0 to {len(df) - 1}")
            print(f"Actual range: {df['ID'].min()} to {df['ID'].max()}")

        # Check predictions range
        if df['Prediction'].min() < 0 or df['Prediction'].max() >= num_classes:
            print(f"Error: Predictions out of valid range (0-{num_classes - 1})")
            return False

        print("\nSubmission file verification:")
        print(f"Number of predictions: {len(df)}")
        print(f"ID range: {df['ID'].min()} to {df['ID'].max()}")
        print(f"Prediction range: {df['Prediction'].min()} to {df['Prediction'].max()}")
        print("\nPrediction distribution:")
        print(df['Prediction'].value_counts().sort_index())

        return True

    except Exception as e:
        print(f"Error verifying submission: {str(e)}")
        return False

def test_model(model, test_dir, submission_path, device, transform=None):
    """
    Test the model and create submission file.

    Entries of ``test_dir`` are processed one by one in sorted filename order;
    the i-th entry (0-based) gets ``ID = i``. Each image is:
    - converted to grayscale;
    - passed through ``transform`` and ``WaveletTransform``;
    - classified.

    Entries that fail are reported and skipped, so their IDs are absent from
    the CSV; a summary of such failures is printed after the loop. The written
    file is checked with ``verify_submission``. Errors outside the per-image
    loop are printed with a traceback instead of being raised.

    Args:
        model: trained network; called on a ``(1, C, H, W)`` wavelet tensor.
        test_dir: directory holding the test images.
        submission_path: where to write the CSV (columns ``ID``, ``Prediction``).
        device: device the model's parameters live on.
        transform: optional image transform returning a tensor. Defaults to
            Resize(48, 48) -> ToTensor -> Normalize(0.485, 0.229), the same
            pipeline ``main`` uses for validation.
    """
    try:
        model.eval()

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((48, 48)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485], std=[0.229])
            ])

        wavelet_transform = WaveletTransform()

        # Create submission dataframe
        results = {'ID': [], 'Prediction': []}
        failed_files = []

        test_files = sorted(os.listdir(test_dir))

        print("\nStarting test prediction...")
        print(f"Found {len(test_files)} test images")

        with torch.no_grad():
            for idx, img_name in enumerate(test_files):  # Start index from 0
                img_path = os.path.join(test_dir, img_name)
                try:
                    # Decode inside ``with`` so the file handle is released immediately
                    with Image.open(img_path) as img:
                        image = img.convert('L')

                    image = transform(image)

                    img_np = image.squeeze().numpy()
                    wavelet_features = wavelet_transform.apply_transform(img_np)
                    wavelet_tensor = torch.from_numpy(wavelet_features).float()

                    # Add batch dimension
                    wavelet_tensor = wavelet_tensor.unsqueeze(0).to(device)

                    outputs = model(wavelet_tensor)
                    _, predicted = outputs.max(1)

                    # ID = position in the sorted listing (starting from 0)
                    results['ID'].append(idx)
                    results['Prediction'].append(predicted.item())

                    if (idx + 1) % 100 == 0:
                        print(f"Processed {idx + 1}/{len(test_files)} images")

                except Exception as e:
                    print(f"Error processing image {img_name}: {str(e)}")
                    print(f"Image path: {img_path}")
                    failed_files.append(img_name)
                    continue

        if failed_files:
            print(f"\nWarning: {len(failed_files)} file(s) could not be processed and "
                  f"their IDs are missing from the submission: {failed_files[:10]}")

        # Create and save submission file
        submission_df = pd.DataFrame(results)
        submission_df = submission_df.sort_values('ID')  # Sort by ID
        submission_df.to_csv(submission_path, index=False)
        print(f"\nSubmission file created at: {submission_path}")
        print(f"Total predictions: {len(submission_df)}")
        print("\nPrediction distribution:")
        print(submission_df['Prediction'].value_counts())

        # Verify submission file
        print("\nVerifying submission file...")
        if verify_submission(submission_path):
            print("Submission file verified successfully")
        else:
            print("Submission file verification failed")

    except Exception as e:
        print(f"Error in test_model: {str(e)}")
        traceback.print_exc()

def _train_and_checkpoint(model, train_loader, val_loader, device, checkpoint_path,
                          epochs=20, patience=10):
    """Train ``model`` and save the weights with the best validation accuracy to ``checkpoint_path``.

    Setup:
    - loss: cross-entropy with label smoothing 0.1;
    - optimizer: ``GC_AdamW`` (lr 1e-3, weight decay 0.01);
    - schedule: ``CosineAnnealingWarmRestarts`` (T_0=10, T_mult=2), stepped
      once per epoch.

    Training stops early after ``patience`` consecutive epochs without a new
    best validation accuracy.
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # AdamW with decoupled weight decay plus gradient centralization
    optimizer = GC_AdamW(
        model.parameters(),
        lr=0.001,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )

    # Cosine annealing scheduler with warm restarts
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2,
        eta_min=1e-6
    )

    best_accuracy = 0.0
    no_improve_epochs = 0

    print("\nStarting training...")
    for epoch in range(epochs):
        train_accuracy = train_model(model, train_loader, criterion, optimizer, device, epoch)
        val_loss, val_accuracy = validate_model(model, val_loader, criterion, device)

        scheduler.step()

        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'Training Accuracy: {train_accuracy:.2f}%')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Validation Accuracy: {val_accuracy:.2f}%')
        # Read after scheduler.step(), so this is the LR the next epoch will use
        print(f'Learning Rate: {scheduler.get_last_lr()[0]:.6f}')

        # Save best model and check for early stopping
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)
            print(f'New best model saved with accuracy: {best_accuracy:.2f}%')
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f'\nEarly stopping after {patience} epochs without improvement')
                break

        # Memory cleanup between epochs
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()


def main(debug_cuda=False):
    """Train KTN (unless a checkpoint already exists), then write the test submission.

    If ``best_ktn_model.pth`` is absent from the working directory, the model
    is trained and the best-validation weights are saved there. The checkpoint
    is then reloaded and ``test_model`` writes the submission CSV. Errors are
    printed with a traceback rather than raised.

    Args:
        debug_cuda: when True, set ``CUDA_LAUNCH_BLOCKING=1`` before CUDA setup
            so CUDA errors are reported at the offending call. This
            serializes kernel launches and slows training, so it is off by
            default.
    """
    checkpoint_path = 'best_ktn_model.pth'
    try:
        if debug_cuda:
            # Only honoured if set before CUDA is initialized (setup_cuda below).
            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
        # TORCH_USE_CUDA_DSA is a PyTorch build-time option; exporting it at
        # runtime does not enable device-side assertions, so it is not set.

        device = setup_cuda()
        pin_memory = device.type == 'cuda'

        # Data transforms with augmentation
        transform_train = transforms.Compose([
            transforms.Resize((48, 48)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485], std=[0.229])
        ])

        transform_test = transforms.Compose([
            transforms.Resize((48, 48)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485], std=[0.229])
        ])

        print("\nInitializing datasets...")
        train_dataset = CustomDataset(
            csv_file='/content/FER/fer-competition/train_label.csv',
            img_dir='/content/FER/fer-competition/FER2013Train',
            transform=transform_train,
            train=True
        )

        val_dataset = CustomDataset(
            csv_file='/content/FER/fer-competition/valid_label.csv',
            img_dir='/content/FER/fer-competition/FER2013Valid',
            transform=transform_test,
            train=True
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=32,
            shuffle=True,
            num_workers=2,
            pin_memory=pin_memory
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=2,
            pin_memory=pin_memory
        )

        print("\nInitializing model...")
        model = KTN(num_classes=10, input_channels=7)
        model = model.to(device)

        print("\nModel structure:")
        print(model)
        print(f"\nModel device: {next(model.parameters()).device}")

        # Training mode: only when no checkpoint exists yet
        if not os.path.exists(checkpoint_path):
            _train_and_checkpoint(model, train_loader, val_loader, device, checkpoint_path)

        # Testing mode
        print("\nLoading best model for testing...")
        try:
            model = KTN(num_classes=10, input_channels=7)
            # The checkpoint is a plain state_dict; weights_only=True refuses to
            # unpickle arbitrary objects from the file.
            state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
            model.load_state_dict(state_dict)

            # Move model to device after loading
            model = model.to(device)
            print("Model loaded successfully")

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            traceback.print_exc()
            return

        model.eval()

        # Run testing and create submission file
        test_model(
            model=model,
            test_dir='/content/FER/fer-competition/FER2013Test',
            submission_path='/content/FER/fer-competition/submission.csv',
            device=device
        )

    except Exception as e:
        print(f"Error in main: {str(e)}")
        traceback.print_exc()
        return

# Script entry point. In this notebook's kernel __name__ is also '__main__',
# so running the cell executes main() directly.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass, so it needs its own clause
        print("\nProcess interrupted by user")
    except Exception as e:
        # Last-resort handler: report and print the traceback instead of re-raising
        print(f"Fatal error: {str(e)}")
        traceback.print_exc()
    finally:
        # Always printed, whether main() finished, failed or was interrupted
        print("\nProgram finished")

PyTorch version: 2.5.1+cu124
CUDA available: True
CUDA version: 12.4
Using GPU: Tesla T4
Memory allocated: 16.25 MB

Initializing datasets...

Dataset size: 28557
CSV file: /content/FER/fer-competition/train_label.csv
Image directory: /content/FER/fer-competition/FER2013Train

CSV columns:
['fer0000000.png', '(0, 0, 48, 48)', '4', '0', '0.1', '1', '3', '2', '0.2', '0.3', '0.4', '0.5']

First few rows of data:
   fer0000000.png  (0, 0, 48, 48)  4  0  0.1  1  3  2  0.2  0.3  0.4  0.5
0  fer0000001.png  (0, 0, 48, 48)  6  0    1  1  0  0    0    0    2    0
1  fer0000002.png  (0, 0, 48, 48)  5  0    0  3  1  0    0    0    1    0
2  fer0000003.png  (0, 0, 48, 48)  4  0    0  4  1  0    0    0    1    0
3  fer0000004.png  (0, 0, 48, 48)  9  0    0  1  0  0    0    0    0    0
4  fer0000005.png  (0, 0, 48, 48)  6  0    0  1  0  0    1    1    1    0

Dataset size: 3578
CSV file: /content/FER/fer-competition/valid_label.csv
Image directory: /content/FER/fer-competition/FER2013Valid

CSV colu