# テストデータを読み込んで推論し、結果を submission.csv に出力する


In [None]:
import os
import numpy as np
import torch
import pandas as pd

# 1) Device setup: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)


# 学習時と同じモデル構造を先に定義し、その後に学習済みの重み（state_dict）を読み込む

In [None]:
from torch import nn

# 2) Model definition
class InversionNet70x70(nn.Module):
    """Small fully-convolutional encoder/decoder: 5 input channels -> 1 output channel.

    Every layer uses kernel_size=3, padding=1 and the default stride of 1, so the
    spatial size is preserved end to end. An input of shape (N, 5, H, W) gives an
    output of shape (N, 1, H, W). The final Sigmoid keeps values in [0, 1].

    NOTE(review): despite the class name, the output is 70x70 only when the input is
    70x70. An (N, 5, 1000, 70) batch produces an (N, 1, 1000, 70) output.

    The attribute names ``encoder``/``decoder`` and the layer order inside each
    ``nn.Sequential`` determine the state_dict keys. Keep them unchanged so that
    pretrained checkpoints still load.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(5, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` into 64 feature channels, then decode to a single channel."""
        features = self.encoder(x)
        return self.decoder(features)

# Load pretrained weights into a fresh InversionNet70x70 on the selected device.
model = InversionNet70x70().to(device)
# Only a state_dict is needed here. weights_only=True restricts unpickling to tensors,
# primitive types and dicts, so loading the checkpoint cannot execute arbitrary
# pickled code. (Requires torch >= 1.13.)
state_dict = torch.load('first_model.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Switch to evaluation mode. This architecture has no dropout/batchnorm, so this is
# mainly a safeguard if such layers are added later.
model.eval()


# 学習済みモデルでテストデータを推論し、提出用の DataFrame（CSV 形式）を作成する

In [None]:
# 3) Inference & CSV generation
# Run the pretrained model on every test .npy file and build a submission-style
# DataFrame. Each (file, y) pair becomes one row keyed by "<file id>_y_<y>", and the
# row holds the odd x columns of that prediction row.
test_dir = './data/waveform-inversion/test'
test_files = sorted(
    os.path.join(test_dir, f) for f in os.listdir(test_dir) if f.endswith('.npy')
)
if not test_files:
    # Fail early with a clear message instead of a KeyError from an empty DataFrame.
    raise FileNotFoundError(f'No .npy files found in {test_dir!r}')

OUT_SIZE = (70, 70)  # (y, x) size of one prediction map in the submission
# Each column is named after the x index it holds: x_1, x_3, ..., x_69 take columns
# 1, 3, ..., 69 of every prediction map.
x_cols = [f'x_{i}' for i in range(1, OUT_SIZE[1], 2)]
cols = ['oid_ypos'] + x_cols

frames = []
n_resized = 0
for fp in test_files:
    arr = np.load(fp)  # assumed (5, 1000, 70) — TODO confirm; not checked here
    inp = torch.from_numpy(arr).unsqueeze(0).float().to(device)  # add a batch axis
    with torch.no_grad():
        pred = model(inp)  # (1, 1, H, W)
        if tuple(pred.shape[-2:]) != OUT_SIZE:
            # InversionNet70x70 keeps the input's spatial size (kernel 3, padding 1,
            # stride 1). A (1000, 70) input therefore yields a (1000, 70) map, which
            # would emit 1000 y-rows per file. Resize to the 70x70 grid so each file
            # gets exactly y_0 .. y_69.
            # NOTE(review): confirm this matches how targets were handled in training.
            pred = torch.nn.functional.interpolate(
                pred, size=OUT_SIZE, mode='bilinear', align_corners=False
            )
            n_resized += 1
    out = pred[0, 0].cpu().numpy()  # (70, 70)
    # NOTE(review): the model ends in a Sigmoid, so values lie in [0, 1]. If targets
    # were normalized during training, de-normalize here before writing.

    fid = os.path.splitext(os.path.basename(fp))[0]
    frame = pd.DataFrame(out[:, 1::2], columns=x_cols)  # odd columns 1, 3, ..., 69
    frame.insert(0, 'oid_ypos', [f'{fid}_y_{y}' for y in range(out.shape[0])])
    frames.append(frame)

if n_resized:
    print(f'NOTE: resized {n_resized} of {len(test_files)} predictions to {OUT_SIZE}.')

df = pd.concat(frames, ignore_index=True)[cols]

df.head()

# Once df.head() looks right, uncomment the line below to write the submission file.
# df.to_csv('submission.csv', index=False)

