In [1]:
import torch
from quant_gans import Generator, Discriminator, load_financial_data, train_quant_gans
import matplotlib.pyplot as plt

In [None]:
# Load financial data
data = load_financial_data('AAPL.csv')

# Run configuration shared by training and sampling below.
SEED = 42            # fixes weight init and noise draws so Restart & Run All is reproducible
BATCH_SIZE = 64      # used both for training and for the sampled batch below
NOISE_LENGTH = 100   # length of each noise sequence fed to the generator
torch.manual_seed(SEED)

# Initialize the generator and the discriminator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(n_layers=8, n_channels=64, n_input=1, n_output=1, kernel_size=2, stride=1, dilation=2, padding=1, dropout=0.2).to(device)
discriminator = Discriminator(n_layers=8, n_channels=64, n_input=1, n_output=1, kernel_size=2, stride=1, dilation=2, padding=1, dropout=0.2).to(device)

# Train the QuantGANs model
train_quant_gans(data, generator, discriminator, num_epochs=200, batch_size=BATCH_SIZE, device=device)

# After training, use the generator to create a batch of synthetic financial time series.
# The batch size comes from BATCH_SIZE: a bare `batch_size` name was never defined in this
# notebook, so referencing it here raised NameError on a fresh kernel.
# torch.no_grad() is used because this is pure sampling; no gradients are needed.
# NOTE(review): if train_quant_gans leaves the generator in training mode, dropout is
# still active here -- consider calling generator.eval() before sampling (confirm intent).
with torch.no_grad():
    noise = torch.randn(BATCH_SIZE, 1, NOISE_LENGTH).to(device)
    generated_data = generator(noise)

In [None]:
# Draw one noise sequence, push it through the generator, and keep the result
# as a flat NumPy array suitable for plotting.
noise = torch.randn(1, 1, 100).to(device)
generated_data = (
    generator(noise)
    .detach()
    .cpu()
    .numpy()
    .flatten()
)

# Plot the generated data using the explicit Figure/Axes interface.
fig, ax = plt.subplots()
ax.plot(generated_data)
ax.set_title("Generated Financial Time Series Data")
ax.set_xlabel("Time")
ax.set_ylabel("Normalized Adjusted Close Price")
plt.show()

## Saving and Loading Trained Models
You can save the trained generator and discriminator to disk for later use. Use `torch.save` to write each model's `state_dict` (its learned parameters) to a file. For example:

# Save the trained generator and discriminator models.
# Only each model's state_dict (its parameters and buffers) is written, not the module object.
GENERATOR_PATH = 'generator.pth'
DISCRIMINATOR_PATH = 'discriminator.pth'

torch.save(generator.state_dict(), GENERATOR_PATH)
torch.save(discriminator.state_dict(), DISCRIMINATOR_PATH)

# To load the saved models back into memory, use torch.load and load_state_dict.
# map_location remaps tensors that were saved on a GPU onto `device`, so the same
# checkpoints also load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust. On
# PyTorch >= 1.13, passing weights_only=True restricts loading to tensor data.
generator.load_state_dict(torch.load(GENERATOR_PATH, map_location=device))
discriminator.load_state_dict(torch.load(DISCRIMINATOR_PATH, map_location=device))