In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import torch
from torch import nn 
import torch.nn.functional as F
from torch import optim
from torch.autograd import grad

from argparse import Namespace
from functools import reduce
import seaborn as sns
import pickle

In [None]:
def get_fft(y, sampling_rate):
  """Return the positive-frequency amplitude spectrum of a sampled signal.

  Args:
    y: 1-D signal, or an array with singleton dimensions such as (N, 1) or
      (1, N); singleton axes are squeezed away before the FFT.
    sampling_rate: Samples per unit of the independent variable. It is passed
      to ``np.fft.fftfreq`` as a sample spacing of ``1 / sampling_rate``.

  Returns:
    A list ``[freqs, amplitudes]`` of two equal-length tuples. It holds only the
    strictly positive frequencies and ``|FFT|`` at each of them. The DC bin is
    dropped, and so is the Nyquist bin for even lengths, because ``fftfreq``
    reports it as negative. If no positive frequency exists (for example a
    single sample), both tuples are empty, so ``freqs, amps = get_fft(...)``
    still unpacks.
  """
  # atleast_1d keeps a single-sample input (e.g. shape (1, 1)) from collapsing
  # to a 0-d array, which np.fft.fft rejects.
  signal = np.atleast_1d(np.asarray(y).squeeze())

  # Measure the *squeezed* signal: len() of the raw input is wrong for row
  # vectors such as (1, N) and would misalign frequencies with FFT bins.
  length = signal.size

  fourier = np.fft.fft(signal)
  freq = np.fft.fftfreq(length, 1 / sampling_rate)

  positive = freq > 0
  return [tuple(freq[positive]), tuple(np.abs(fourier)[positive])]

In [None]:
def compare_solutions(t, y_true, y_pred, domain_len, coords=None, n_points=256,
                      steps_per_unit_time=100):
  """Plot exact vs. predicted solutions at one time, plus their Fourier spectra.

  Draws three panels: the solutions over x, their positive-frequency amplitude
  spectra (via ``get_fft``), and the absolute difference of those spectra.

  The flattened arrays are indexed as consecutive windows of ``n_points``
  samples, with window ``round(n_points * t * steps_per_unit_time)`` used for
  time ``t``. NOTE(review): this layout (one spatial grid per time step) is
  inferred from the indexing; confirm it against how ``X``/``y`` are built.

  Args:
    t: Time at which to compare the solutions.
    y_true: Flattened exact solution values.
    y_pred: Flattened predicted solution values, aligned with ``y_true``.
    domain_len: Length of the spatial domain. The FFT sampling rate is
      ``n_points / domain_len``.
    coords: 2-D input array whose column 0 holds the x coordinate. Defaults
      to the notebook-level ``X`` for backward compatibility.
    n_points: Number of spatial samples per time step (was hard-coded 256).
    steps_per_unit_time: Time steps per unit of ``t`` (was hard-coded 100).

  Raises:
    ValueError: if the window for ``t`` falls outside the data.
  """
  if coords is None:
    # Preserve the original behaviour of implicitly using the global ``X``.
    coords = X

  start = round(n_points * t * steps_per_unit_time)
  end = start + n_points
  available = min(len(y_true), len(y_pred), len(coords))
  if start < 0 or end > available:
    # A truncated window would be plotted with the wrong FFT sampling rate.
    raise ValueError(
        f"Time {t} maps to samples [{start}, {end}), outside the "
        f"{available} available samples."
    )

  x_window = coords[start:end, 0]
  true_window = y_true[start:end]
  pred_window = y_pred[start:end]
  sampling_rate = n_points / domain_len

  fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 5))

  ax0.plot(x_window, true_window, "b-", label="Exact", linewidth=3)
  ax0.plot(x_window, pred_window, "r--", label="Prediction", linewidth=3)
  ax0.legend(fontsize=12)
  ax0.set_title(f"Comparing Solutions at Time {t}")
  ax0.set_xlabel("x")
  ax0.set_ylabel("u(x, t)")

  true_freq, true_fourier = get_fft(true_window, sampling_rate)
  pred_freq, pred_fourier = get_fft(pred_window, sampling_rate)
  ax1.plot(true_freq, true_fourier, "b-", label="Exact", linewidth=3)
  ax1.plot(pred_freq, pred_fourier, "r--", label="Prediction", linewidth=3)
  ax1.legend(fontsize=12)
  ax1.set_title(f"Comparing Fourier Transforms of Solutions at Time {t}")
  ax1.set_xlabel("Frequency")
  ax1.set_ylabel("Amplitude")

  # Both windows have n_points samples, so the two spectra share frequency bins.
  fourier_diff = np.abs(np.asarray(true_fourier) - np.asarray(pred_fourier))
  ax2.plot(true_freq, fourier_diff, "b-", linewidth=3)
  ax2.set_title(f"Difference in Fourier Transforms of Solutions at Time {t}")
  ax2.set_xlabel("Frequency")
  ax2.set_ylabel("Absolute Difference in Amplitude")

  plt.show()

In [None]:
# Pickled "larger spectrum" Fourier-analysis results.
# TODO: make this path configurable instead of an absolute, machine-specific path.
LARGER_SPECTRUM_PATH = '/home/ec2-user/project/CS229-Foundations-of-Deep-Learning/fourier_analyses/larger_spectrum/larger_spectrum_10.pkl'

# pickle.load needs a binary file object, so open with 'rb'; text mode makes it
# raise TypeError. Only unpickle trusted files: pickle can execute arbitrary code.
with open(LARGER_SPECTRUM_PATH, 'rb') as f:
    larger_spectrum = pickle.load(f)